DQN 模型解析及 Pytorch 完整代码

深度强化学习(Deep Reinforcement Learning)是强化学习与深度学习相结合的重要领域。它的一个经典模型是 DQN(Deep Q-Network),是由谷歌 DeepMind 在 2013 年提出的。DQN 通过深度神经网络来逼近 Q 值函数,在 Atari 游戏中取得了显著的成功。

DQN 模型简介

强化学习的核心是 Q 学习(Q-learning),其目标是通过学习一个 Q 函数 ( Q(s, a) ) 来评估在状态 ( s ) 采取动作 ( a ) 的期望未来收益。DQN 模型的创新之处在于使用深度神经网络来估计 Q 值,通过将离散动作空间的 Q 值映射到神经网络输出。

DQN 的训练过程主要包括以下几个步骤:

  1. 经验回放(Experience Replay):将智能体的经历(状态、动作、奖励、新状态)存储在经验回放池中,并从中随机抽取小批量数据进行训练。这样可以打破数据之间的相关性,提高训练效果。

  2. 目标网络(Target Network):引入一个目标网络来计算 Q 值,以提高训练的稳定性。目标网络的参数在每隔一段时间后才会更新为主网络的参数,从而避免训练过程中的震荡。

  3. 损失函数:DQN 使用均方误差损失函数来更新 Q 值。损失函数定义为: [ L = \mathbb{E} \left[ \left( r + \gamma \max_{a'} Q'(s', a'; \theta^{-}) - Q(s, a; \theta) \right)^2 \right] ] 其中,( r ) 是当前奖励,( \gamma ) 是折扣因子,( Q' ) 是目标网络,( Q ) 是主网络。

Pytorch 实现 DQN

以下是一个基于 Pytorch 的简单 DQN 算法实现,应用于 OpenAI Gym 中的 CartPole 环境。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import gym

# 超参数
EPISODES = 1000
MAX_T = 200
GAMMA = 0.99
LEARNING_RATE = 0.001
MEMORY_SIZE = 10000
BATCH_SIZE = 64
TARGET_UPDATE_FREQ = 10

# 网络结构
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# 训练
def train_dqn():
    env = gym.make('CartPole-v1')
    input_size = env.observation_space.shape[0]
    output_size = env.action_space.n

    policy_net = DQN(input_size, output_size)
    target_net = DQN(input_size, output_size)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
    memory = deque(maxlen=MEMORY_SIZE)
    steps_done = 0

    for episode in range(EPISODES):
        state = env.reset()
        for t in range(MAX_T):
            steps_done += 1
            epsilon = 1.0 / (1 + steps_done / 200)  # epsilon-greedy策略
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    action = policy_net(torch.FloatTensor(state)).argmax().item()

            next_state, reward, done, _ = env.step(action)
            if done:
                reward = -1

            memory.append((state, action, reward, next_state, done))
            state = next_state

            if len(memory) >= BATCH_SIZE:
                transitions = random.sample(memory, BATCH_SIZE)
                batch = list(zip(*transitions))
                states = torch.FloatTensor(batch[0])
                actions = torch.LongTensor(batch[1])
                rewards = torch.FloatTensor(batch[2])
                next_states = torch.FloatTensor(batch[3])
                dones = torch.FloatTensor(batch[4])

                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
                next_q_values = target_net(next_states).max(1)[0]
                expected_q_values = rewards + (GAMMA * next_q_values * (1 - dones))

                loss = nn.MSELoss()(q_values, expected_q_values.detach())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if episode % TARGET_UPDATE_FREQ == 0:
            target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {episode + 1}/{EPISODES}, Score: {t + 1}")

if __name__ == "__main__":
    train_dqn()

代码解析

这个代码实现了一个基本的 DQN 算法,使用 Pytorch 框架构建深度神经网络。首先,我们定义了 DQN 网络结构,包含三层全连接层。然后,在 train_dqn 函数中,初始化环境、网络、优化器和经验回放池。通过 epsilon-greedy 策略选择动作,并将经验存储到回放池中。

在学习过程中,我们从回放池中随机抽取一批经验,利用这些经验训练 DQN 网络。在每个 episode 的末尾,我们定期更新目标网络的参数。通过这样的方式,我们能够让 DQN 不断地改进,从而在环境中获得更好的表现。

总结来说,DQN 是一个强大的深度强化学习模型,通过结合经验回放和目标网络等技术,显著提高了 Q 学习算法的效率和效果。通过 Pytorch 实现 DQN,不仅可以帮助我们更好地理解深度学习原理,还可以为增强学习算法的研究与应用提供实用工具。

点赞(0) 打赏

微信小程序

微信扫一扫体验

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部