深度 Q 网络 (Deep Q-Network, DQN)
DQN 是 Q-learning 的深度学习版本。传统 Q-learning 用表格存储 Q(s,a),但当状态空间很大(如图像输入)时,表格方法不可行。
DQN 使用神经网络近似 Q 函数,即:
玩 Atari 游戏时,屏幕像素就是状态,无法用表格记录所有状态。
DQN 的思路:
用卷积神经网络从图像中提取特征。
输出每个动作的 Q 值。
用 Q-learning 的目标函数训练网络。
DQN 还引入了两个关键技巧,稳定训练:
经验回放 (Experience Replay) :将过往经验 (s,a,r,s′) 存入缓冲区,随机采样训练,避免样本相关性。
目标网络 (Target Network) :复制一份参数 θ⁻,固定一段时间不更新,用来计算目标值,减少震荡。
在 DQN 中,参数 θ 表示 Q 网络,θ⁻ 表示目标网络。
更新目标:
损失函数:
通过梯度下降更新 θ,使得 Q 网络预测逼近 Q-learning 的目标。
简化环境:
两个状态 S1, S2;两个动作 A1, A2。
奖励:S1-A1 → S2,奖励 +1;其余动作奖励 0。
γ=1。
假设当前网络估计:
Q(S1,A1)=0.5, Q(S1,A2)=0.2
Q(S2,A1)=0.3, Q(S2,A2)=0.4
一次交互:执行 (S1,A1),得到 r=+1, 下一状态 S2。
目标:
y = r + γ * max Q(S2, a; θ⁻) = 1 + 0.4 = 1.4
损失:
L = (1.4 - 0.5)^2 = 0.81
训练后,Q(S1,A1) 会被推向 1.4,更接近最优值。
以下是一个简化版 DQN 实现(以 Gym 的 CartPole 为例):
import random
import gym
import numpy as np
import torch
import torch .nn as nn
import torch .optim as optim
from collections import deque
# Q 网络
class QNetwork (nn .Module ):
def __init__ (self , state_dim , action_dim ):
super (QNetwork , self ).__init__ ()
self .layers = nn .Sequential (
nn .Linear (state_dim , 64 ),
nn .ReLU (),
nn .Linear (64 , 64 ),
nn .ReLU (),
nn .Linear (64 , action_dim )
)
def forward (self , x ):
return self .layers (x )
# DQN 训练器
class DQNAgent :
def __init__ (self , state_dim , action_dim ):
self .q_net = QNetwork (state_dim , action_dim )
self .target_net = QNetwork (state_dim , action_dim )
self .target_net .load_state_dict (self .q_net .state_dict ())
self .memory = deque (maxlen = 10000 )
self .optimizer = optim .Adam (self .q_net .parameters (), lr = 1e-3 )
self .gamma = 0.99
self .batch_size = 64
self .update_target_freq = 100
self .steps = 0
self .action_dim = action_dim
def select_action (self , state , eps = 0.1 ):
if random .random () < eps :
return random .randrange (self .action_dim )
state = torch .FloatTensor (state ).unsqueeze (0 )
return self .q_net (state ).argmax ().item ()
def store (self , s , a , r , s_ , done ):
self .memory .append ((s ,a ,r ,s_ ,done ))
def update (self ):
if len (self .memory ) < self .batch_size :
return
batch = random .sample (self .memory , self .batch_size )
s , a , r , s_ , d = zip (* batch )
s = torch .FloatTensor (s )
a = torch .LongTensor (a ).unsqueeze (1 )
r = torch .FloatTensor (r ).unsqueeze (1 )
s_ = torch .FloatTensor (s_ )
d = torch .FloatTensor (d ).unsqueeze (1 )
q_values = self .q_net (s ).gather (1 , a )
with torch .no_grad ():
target = r + self .gamma * (1 - d ) * self .target_net (s_ ).max (1 , keepdim = True )[0 ]
loss = nn .MSELoss ()(q_values , target )
self .optimizer .zero_grad ()
loss .backward ()
self .optimizer .step ()
self .steps += 1
if self .steps % self .update_target_freq == 0 :
self .target_net .load_state_dict (self .q_net .state_dict ())
# 运行示例
env = gym .make ("CartPole-v1" )
agent = DQNAgent (env .observation_space .shape [0 ], env .action_space .n )
for episode in range (10 ): # 简短演示
state = env .reset ()[0 ]
done = False
while not done :
action = agent .select_action (state , eps = 0.1 )
next_state , reward , done , _ , _ = env .step (action )
agent .store (state , action , reward , next_state , done )
agent .update ()
state = next_state
DQN 用神经网络近似 Q 函数,解决大规模状态空间问题。
关键技巧:经验回放 + 目标网络。
损失函数基于 Q-learning 的 TD 目标。
为后续改进(Double DQN, Dueling DQN, Prioritized Replay 等)奠定了基础。
笔记是AI生成的,目前来看有一些错误,后面慢慢检查。