【RL Latest Tech】离线强化学习:保守Q学习 (CQL) 算法
Conservative Q-Learning (CQL) 是由Sergey Levine及其团队于2020年提出的一种针对离线强化学习的算法。CQL旨在解决离线强化学习中的两个主要问题:分布偏移(Distributional Shift) 和 过度乐观的值函数估计(Overestimation of Q-Values)。CQL通过对Q值的保守约束,确保学习到的策略更为稳健,避免过度依赖于离线数据
📢本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:
【强化学习】(20)---《离线强化学习:保守Q学习 (CQL) 算法》
离线强化学习:保守Q学习 (CQL) 算法
目录
1.CQL算法简介
Conservative Q-Learning (CQL) 是由Sergey Levine及其团队于2020年提出的一种针对离线强化学习的算法。CQL旨在解决离线强化学习中的两个主要问题:分布偏移(Distributional Shift) 和 过度乐观的值函数估计(Overestimation of Q-Values)。CQL通过对Q值的保守约束,确保学习到的策略更为稳健,避免过度依赖于离线数据中的稀有样本或未充分探索的区域。
2.CQL算法的背景与动机
在离线强化学习中,智能体无法直接与环境交互,而是依赖于历史数据来学习最优策略。然而,由于历史数据往往是基于某些固定策略生成的,新策略在状态空间中未充分探索的区域可能会导致估计值函数出现过度乐观的问题。这种现象称为分布偏移,会导致模型的性能在实际环境中大幅下降。CQL通过在Q学习的过程中加入保守性约束,减轻这些问题。
3.CQL算法实现
3.1 CQL的核心思想
CQL的核心思想是通过对Q值施加保守约束,来防止策略在离线数据中未充分覆盖的状态和动作上产生过度乐观的值估计。具体而言,CQL在标准的Q-learning目标函数中加入了一个额外的正则化项,限制估计的Q值不超过某个合理的范围。
3.2 CQL的目标函数
传统的Q学习更新目标是最小化贝尔曼残差:
CQL在此基础上提出了一种保守的修正,通过引入正则项来约束Q值估计,避免在稀有样本或未充分探索的区域中产生过高的Q值估计。CQL的目标函数形式如下:
其中:
- 第一项是针对当前策略下的动作Q值的期望。
- 第二项则是对历史数据中实际观测到的动作的Q值期望。
- 是一个超参数,用来调节保守程度。
3.3 关键公式解析
CQL通过减少策略生成的Q值期望,并增加数据中真实动作的Q值期望,确保估计值函数更加保守,从而有效降低分布偏移带来的不良影响。这种约束使得模型不倾向于给在历史数据中未出现过或出现很少的状态-动作对分配过高的Q值。
CQL的优化目标可以总结为:
其中:
- 是离线数据集,表示历史数据中的状态-动作对。
- 是即时奖励。
- 是折扣因子,用来衡量未来奖励的重要性。
- 是正则化系数,调节了CQL对保守性约束的强度。
通过增加这个保守性正则化项,CQL确保估计的Q值不会因为稀有样本或未探索区域的高Q值而产生策略偏差。
3.4 CQL算法的执行流程
CQL算法的执行过程如下:
- 收集数据:从环境中收集历史交互数据,形成一个离线数据集 。
- 初始化Q值函数:初始化Q值函数 。
- 迭代更新Q值:通过最小化CQL目标函数,更新Q值函数。每次更新时,首先基于当前策略和历史数据计算Q值,然后将保守性正则项加入损失函数,进行梯度下降优化。
- 策略更新:基于更新后的Q值函数,更新策略,使得策略趋向于选择具有高Q值的动作,同时不过度偏离历史数据中的行为。
- 收敛检测:重复上述步骤,直到Q值和策略收敛。
3.5 CQL算法的优势
- 分布偏移鲁棒性:CQL通过对策略的行为施加约束,有效应对分布偏移问题,使得策略不容易在历史数据未覆盖的区域出现过度乐观的决策。
- 值函数的保守估计:通过在Q值更新时引入正则化,CQL减少了估计值函数的偏差,防止离线数据的稀疏性导致的过高Q值估计。
- 适用场景广泛:CQL适用于那些无法进行在线探索的高风险场景,如医疗决策、自动驾驶、机器人控制等。
3.6 实验结果
在多个离线强化学习基准任务中,CQL算法显示出了卓越的性能,特别是在数据分布偏移严重的情况下。与传统的Q-learning算法相比,CQL的策略更为稳健,能够在没有在线探索的情况下取得较好的决策效果。
4.总结
CQL作为一种专为离线强化学习设计的算法,通过对Q值函数施加保守性约束,解决了离线学习中的分布偏移和Q值过度乐观的问题。它的提出为离线强化学习在高风险、高成本应用中的落地提供了重要的理论基础和实践指导。
如需阅读详细文献,可以访问原文链接:Conservative Q-Learning for Offline Reinforcement Learning。
[Python] CQL算法实现1
🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。
在经典的 CartPole
环境中实现 CQL 算法,可以使用 gym
提供的环境来测试 CQL 算法的性能。CartPole
环境的立杆的环境,如下所示,通常用来简单的测试下强化学习算法。
下面是实现代码,包含了环境创建、CQL 算法应用和测试逻辑。
安装依赖
在运行代码之前,确保安装以下依赖:
pip install torch gym numpy
实现 CQL 在 CartPole 环境中的算法
"""《CQL算法在 CartPole 环境中实现》
时间:2024.07.27
环境:gym-taxi
作者:不去幼儿园
"""
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Q-network definition
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# CQL algorithm
class ConservativeQLearning:
def __init__(self, state_dim, action_dim, gamma=0.99, alpha=1.0, lr=3e-4):
self.q_network = QNetwork(state_dim, action_dim)
self.target_q_network = QNetwork(state_dim, action_dim)
self.target_q_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
self.gamma = gamma
self.alpha = alpha
def update(self, batch, policy):
states, actions, rewards, next_states, dones = batch
# Compute target Q value
with torch.no_grad():
next_q_values = self.target_q_network(next_states)
max_next_q_values = torch.max(next_q_values, dim=1)[0]
target_q_values = rewards + self.gamma * max_next_q_values * (1 - dones)
# Compute current Q value
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
# Compute conservative Q-learning loss
q_diff = self.q_network(states) - policy(states)
conservative_loss = torch.mean(torch.logsumexp(q_diff / self.alpha, dim=1) - torch.mean(q_diff, dim=1))
loss = nn.MSELoss()(current_q_values, target_q_values) + self.alpha * conservative_loss
# Update Q network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def soft_update(self, tau=0.005):
for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
# Replay buffer to store experience
class ReplayBuffer:
def __init__(self, capacity, state_dim):
self.capacity = capacity
self.buffer = []
self.position = 0
self.state_dim = state_dim
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in batch])
return (
torch.FloatTensor(np.array(states)),
torch.LongTensor(actions),
torch.FloatTensor(rewards),
torch.FloatTensor(np.array(next_states)),
torch.FloatTensor(dones),
)
def __len__(self):
return len(self.buffer)
# 状态归一化
class Normalizer:
def __init__(self, state_dim):
self.state_dim = state_dim
self.state_mean = np.zeros(state_dim)
self.state_std = np.ones(state_dim)
self.count = 0
def update(self, state):
self.count += 1
self.state_mean += (state - self.state_mean) / self.count
self.state_std = np.sqrt(((self.count - 1) * self.state_std ** 2 + (state - self.state_mean) ** 2) / self.count)
def normalize(self, state):
return (state - self.state_mean) / (self.state_std + 1e-8)
# CartPole environment CQL training
def train_cql_in_cartpole(env_name="CartPole-v1", num_episodes=500, batch_size=64, buffer_capacity=10000, alpha=1.0):
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
cql = ConservativeQLearning(state_dim, action_dim, alpha=alpha, lr=1e-4)
# 初始化ReplayBuffer时增加容量
buffer_capacity = 50000 # 增大容量至50000
replay_buffer = ReplayBuffer(buffer_capacity, state_dim)
policy = lambda x: cql.q_network(x)
total_rewards = []
# 使用Normalizer对状态进行归一化
normalizer = Normalizer(state_dim)
for episode in range(num_episodes):
state, _ = env.reset()
episode_reward = 0
done = False
while not done:
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = cql.q_network(state_tensor)
action = torch.argmax(q_values).item()
next_state, reward, done, _, __ = env.step(action)
replay_buffer.push(state, action, reward, next_state, done)
episode_reward += reward
state = next_state
if len(replay_buffer) > batch_size:
batch = replay_buffer.sample(batch_size)
cql.update(batch, policy)
cql.soft_update()
total_rewards.append(episode_reward)
if episode % 10 == 0:
print(f"Episode {episode}/{num_episodes}, Reward: {episode_reward}")
# if np.mean(total_rewards[-10:]) > 195: # Terminate early if environment solved
# print("Solved CartPole!")
# break
return total_rewards
# Train and test CQL in CartPole
rewards = train_cql_in_cartpole(num_episodes=1000)
# Plot the rewards over time
import matplotlib.pyplot as plt
plt.plot(rewards)
plt.title("CQL Training in CartPole")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.show()
要在本地环境中显示 CartPole
环境的测试动画,可以使用 gym
提供的内置功能来渲染环境。在本地运行以下代码,以实现测试动画的显示。
import gym
import torch
# 测试CQL在CartPole环境中的表现并显示动画
# 测试CQL在CartPole环境中的表现并显示动画
def test_cql_in_cartpole(env_name="CartPole-v1", num_episodes=5):
env = gym.make(env_name, render_mode="human") # 指定渲染模式为 human
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
cql = ConservativeQLearning(state_dim, action_dim)
for episode in range(num_episodes):
state, _ = env.reset()
done = False
total_reward = 0
while not done:
env.render() # 显示环境动画
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = cql.q_network(state_tensor)
action = torch.argmax(q_values).item()
next_state, reward, done, _, __ = env.step(action) # Gym新版本返回五个值
total_reward += reward
state = next_state
print(f"Episode {episode}, Total Reward: {total_reward}")
env.close()
# 运行测试并显示动画
test_cql_in_cartpole(num_episodes=10)
[Results] 运行结果
观察到奖励值不收敛,下文[注意事项1]对此进行相关分析,也是强化学习其他算法中出现奖励值不收敛情况下,可以采取的相关操作,具备一定借鉴意义
[Notice]代码解释
- QNetwork: 实现了一个简单的三层神经网络,用于估计 Q 值。
- ConservativeQLearning: 实现了 CQL 算法。
update
方法中包含了保守性正则化项,用于约束 Q 值的估计。soft_update
用于对目标 Q 网络进行软更新。 - ReplayBuffer: 实现了经验回放缓冲区,用于存储环境交互过程中的样本,并从中随机采样批量数据进行更新。
- train_cql_in_cartpole: 这是核心训练函数,用于在
CartPole
环境中训练 CQL 算法。每个回合中,智能体会与环境交互,生成样本并存储到经验回放缓冲区中。当缓冲区中有足够的样本时,CQL 开始更新 Q 网络。 - 结果测试: 观察
CartPole
杆子的平衡动作
运行该代码时,CQL 算法会逐渐学习到如何在 CartPole
环境中保持杆子的平衡。可以通过绘制奖励随回合变化的曲线来观察训练效果。
[Notice] 注意事项1
奖励值不收敛的原因可能有多种,尤其是在离线强化学习的情况下。以下是一些可能导致奖励不收敛的常见原因,以及如何修改代码以提高模型的表现:
1. Replay Buffer 不足
在离线强化学习中,经验回放缓冲区用于存储从环境中收集的样本。如果 ReplayBuffer
的容量太小或者填充不充分,模型将无法获得足够的数据进行有效的学习。这可能会导致模型更新不稳定,表现为奖励不收敛。
修改:
增加 ReplayBuffer
的容量,并确保在开始训练之前,缓冲区中存有足够的数据。此外,增加探索动作的随机性,以确保缓冲区中包含多样化的样本。
# Replay buffer to store experience
class ReplayBuffer:
def __init__(self, capacity, state_dim):
self.capacity = capacity
self.buffer = []
self.position = 0
self.state_dim = state_dim
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in batch])
return (
torch.FloatTensor(np.array(states)),
torch.LongTensor(actions),
torch.FloatTensor(rewards),
torch.FloatTensor(np.array(next_states)),
torch.FloatTensor(dones),
)
def __len__(self):
return len(self.buffer)
# 初始化ReplayBuffer时增加容量
buffer_capacity = 50000 # 增大容量至50000
replay_buffer = ReplayBuffer(buffer_capacity, state_dim)
2. 过早或不充分的探索
在强化学习中,探索是至关重要的。如果模型仅通过贪婪策略(即始终选择最优动作)进行探索,可能会过早陷入局部最优解。特别是在早期训练阶段,应通过增加动作的随机性来鼓励更多的探索。
修改:
在选择动作时加入 epsilon-greedy 策略。初期随机选择动作的概率应较高,然后逐渐减少以收敛到最佳策略。
# Add epsilon-greedy action selection for exploration
def select_action(q_values, epsilon):
if np.random.rand() < epsilon:
return np.random.randint(q_values.shape[-1]) # 随机选择动作
else:
return torch.argmax(q_values).item() # 选择Q值最大的动作
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.995
epsilon = epsilon_start
for episode in range(num_episodes):
state, _ = env.reset()
state = state[0]
done = False
episode_reward = 0
while not done:
env.render()
# Select action using epsilon-greedy strategy
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = cql.q_network(state_tensor)
action = select_action(q_values, epsilon)
next_state, reward, done, _, __ = env.step(action)
replay_buffer.push(state, action, reward, next_state, done)
episode_reward += reward
state = next_state
if len(replay_buffer) > batch_size:
batch = replay_buffer.sample(batch_size)
cql.update(batch, policy)
cql.soft_update()
# 更新epsilon值,减少随机性,逐渐过渡到贪婪策略
epsilon = max(epsilon_end, epsilon * epsilon_decay)
print(f"Episode {episode}, Reward: {episode_reward}")
3. 学习率过高或过低
学习率会影响模型的训练速度和稳定性。过高的学习率会导致更新不稳定,表现为奖励不收敛;而过低的学习率则会导致训练过慢,甚至无法学到有效的策略。
修改:
调整优化器的学习率,尝试稍微降低学习率以查看是否能提高模型的稳定性。
# 调整学习率
cql = ConservativeQLearning(state_dim, action_dim, alpha=alpha, lr=1e-4) # 将学习率降低至1e-4
4. CQL 保守性正则化的调整
CQL 算法中的保守性正则化项用于控制Q值的保守估计。如果 过大,会使模型过于保守,从而导致模型难以收敛到最优策略。相反,过小的会使模型忽略保守性,导致分布偏移问题。
修改:
尝试不同的 值来找到合适的正则化强度。
# 调整保守性参数alpha
cql = ConservativeQLearning(state_dim, action_dim, alpha=0.5) # 调整 alpha 参数为 0.5
5. 状态标准化或归一化
强化学习中的状态通常具有不同的尺度。对状态进行标准化或归一化有助于提高模型的收敛性,尤其是当输入状态具有较大的范围差异时。
修改:
在训练过程中,对状态数据进行归一化或标准化。
# 状态归一化
class Normalizer:
def __init__(self, state_dim):
self.state_dim = state_dim
self.state_mean = np.zeros(state_dim)
self.state_std = np.ones(state_dim)
self.count = 0
def update(self, state):
self.count += 1
self.state_mean += (state - self.state_mean) / self.count
self.state_std = np.sqrt(((self.count - 1) * self.state_std**2 + (state - self.state_mean)**2) / self.count)
def normalize(self, state):
return (state - self.state_mean) / (self.state_std + 1e-8)
# 使用Normalizer对状态进行归一化
normalizer = Normalizer(state_dim)
for episode in range(num_episodes):
state, _ = env.reset()
normalizer.update(state)
state = normalizer.normalize(state)
done = False
episode_reward = 0
while not done:
env.render()
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = cql.q_network(state_tensor)
action = select_action(q_values, epsilon)
next_state, reward, done, _, __ = env.step(action)
normalizer.update(next_state)
next_state = normalizer.normalize(next_state)
replay_buffer.push(state, action, reward, next_state, done)
episode_reward += reward
state = next_state
if len(replay_buffer) > batch_size:
batch = replay_buffer.sample(batch_size)
cql.update(batch, policy)
cql.soft_update()
epsilon = max(epsilon_end, epsilon * epsilon_decay)
print(f"Episode {episode}, Reward: {episode_reward}")
总结:通过增加经验回放缓冲区的容量、引入
epsilon-greedy
策略增强探索、调整学习率和正则化参数、以及对状态进行标准化,可以提高模型的收敛性。你可以根据实验结果,逐步调整这些参数以优化算法的表现。
[论文环境介绍]
在原论文《Conservative Q-Learning for Offline Reinforcement Learning》中,作者采用了多个标准的强化学习基准环境来评估 CQL 算法的性能。这些环境大部分来自于 D4RL(Datasets for Deep Data-Driven Reinforcement Learning) 基准库,该库专门为离线强化学习设计,包含多个不同领域的任务数据集。主要使用的环境包括以下几类:
1. MuJoCo 控制任务
MuJoCo 是一个广泛用于机器人和连续控制任务的物理模拟环境,CQL 算法在多个 MuJoCo 环境上进行了测试,包括:
- HalfCheetah:一种机器人四足动物,任务是使其在水平面上奔跑。
- Hopper:一种单足跳跃的机器人,任务是跳跃并保持平衡。
- Walker2d:一个两足行走机器人,任务是让机器人向前行走并保持稳定。
这些任务的目标是学习到能够稳定控制这些复杂机械体的策略,同时尽可能高效地完成给定任务。MuJoCo 任务采用连续动作空间,非常适合评估离线强化学习算法在复杂控制问题上的表现。
2. Ant 环境
Ant:四足机器人任务,机器人需要学会如何在平面上稳定地移动,避免摔倒。与 HalfCheetah 和 Walker2d 类似,Ant 也是一个连续动作空间任务,适合评估强化学习算法在高维控制空间下的表现。
3. Adroit 任务
Adroit 是用于模拟手部控制任务的环境,涉及复杂的手指操控。具体任务包括:
- Pen:任务是用手操控笔,使其在手中平稳旋转。
- Hammer:任务是控制手部使用锤子进行敲击。
- Door:任务是控制手打开门。
- Relocate:任务是使用手重新定位某个物体。
这些任务具有高维的状态和动作空间,尤其适合测试算法在稀疏奖励和高难度操控任务下的表现。
4. Maze2D 环境
Maze2D 是一个迷宫导航任务,目标是让智能体从起点导航到终点。这类任务主要考验智能体在二维平面上的导航能力,并涉及稀疏奖励的情况。
5. Gym Environments(e.g., CartPole)
虽然在 CQL 原始论文中主要以 MuJoCo 和 Adroit 这样的复杂环境为主,但经典的 Gym 环境如 CartPole 等也可用于算法的基础测试和验证。在一些早期实验或简单环境测试中,CartPole 也常用来验证算法的有效性。
6. 其他任务
除了以上任务,CQL 还在一些其他基准环境中进行了测试,包括 Atari 游戏等经典强化学习任务。不过,论文的主要实验集中在 D4RL 基准库中的连续控制和手部操控任务上。
D4RL 基准库链接:
论文提到的 D4RL 基准库可以在以下链接找到:D4RL: Datasets for Deep Data-Driven Reinforcement Learning
CQL 算法在多个复杂的连续控制和操控任务上展示了强大的性能,尤其在 MuJoCo 和 Adroit 等环境中的表现证明了它应对分布偏移和过度乐观估计问题的有效性。这些任务具有高维状态和动作空间,非常适合测试离线强化学习算法的稳健性和性能。
[ 原文实现] CQL算法实现2
🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。
1. 安装必要的库
首先,确保安装了 d4rl
和 mujoco-py
,以及之前提到的 gym
和 torch
。
pip install d4rl mujoco-py torch gym numpy
2. 改写代码以使用 D4RL 环境
import gym
import d4rl # Import D4RL库来加载离线数据
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Q-network definition
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# CQL algorithm
class ConservativeQLearning:
def __init__(self, state_dim, action_dim, gamma=0.99, alpha=1.0, lr=3e-4):
self.q_network = QNetwork(state_dim, action_dim)
self.target_q_network = QNetwork(state_dim, action_dim)
self.target_q_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
self.gamma = gamma
self.alpha = alpha
def update(self, batch, policy):
states, actions, rewards, next_states, dones = batch
# Compute target Q value
with torch.no_grad():
next_q_values = self.target_q_network(next_states)
max_next_q_values = torch.max(next_q_values, dim=1)[0]
target_q_values = rewards + self.gamma * max_next_q_values * (1 - dones)
# Compute current Q value
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
# Compute conservative Q-learning loss
q_diff = self.q_network(states) - policy(states)
conservative_loss = torch.mean(torch.logsumexp(q_diff / self.alpha, dim=1) - torch.mean(q_diff, dim=1))
loss = nn.MSELoss()(current_q_values, target_q_values) + self.alpha * conservative_loss
# Update Q network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def soft_update(self, tau=0.005):
for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
# Replay buffer to store experience
class ReplayBuffer:
def __init__(self, capacity, state_dim):
self.capacity = capacity
self.buffer = []
self.position = 0
self.state_dim = state_dim
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in batch])
return (
torch.FloatTensor(np.array(states)),
torch.LongTensor(actions),
torch.FloatTensor(rewards),
torch.FloatTensor(np.array(next_states)),
torch.FloatTensor(dones),
)
def __len__(self):
return len(self.buffer)
训练函数代码,测试函数参照上文的改写,实现的方法相近,修改相关参数和代码即可。
# 训练CQL算法并在D4RL环境中测试
def train_cql_in_d4rl(env_name="halfcheetah-medium-v2", num_episodes=100, batch_size=256, buffer_capacity=100000, alpha=1.0):
# 加载D4RL环境
env = gym.make(env_name)
dataset = env.get_dataset()
# 提取数据集
states = torch.FloatTensor(dataset['observations'])
actions = torch.FloatTensor(dataset['actions'])
rewards = torch.FloatTensor(dataset['rewards'])
next_states = torch.FloatTensor(dataset['next_observations'])
dones = torch.FloatTensor(dataset['terminals'])
state_dim = states.shape[1]
action_dim = actions.shape[1]
# 初始化CQL和经验回放缓冲区
cql = ConservativeQLearning(state_dim, action_dim, alpha=alpha)
replay_buffer = ReplayBuffer(buffer_capacity, state_dim)
# 将数据集加载到ReplayBuffer中
for i in range(len(states)):
replay_buffer.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
policy = lambda x: cql.q_network(x)
total_rewards = []
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = cql.q_network(state_tensor)
action = torch.argmax(q_values).detach().numpy()
next_state, reward, done, _ = env.step(action)
replay_buffer.push(state, action, reward, next_state, done)
episode_reward += reward
state = next_state
if len(replay_buffer) > batch_size:
batch = replay_buffer.sample(batch_size)
cql.update(batch, policy)
cql.soft_update()
total_rewards.append(episode_reward)
print(f"Episode {episode}, Reward: {episode_reward}")
return total_rewards
# 训练并绘制奖励曲线
rewards = train_cql_in_d4rl(num_episodes=100)
import matplotlib.pyplot as plt
plt.plot(rewards)
plt.title("CQL Training in D4RL (halfcheetah-medium-v2)")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.show()
[Notice] 注意事项2
要运行基于 MuJoCo 环境的 D4RL 基准库,您需要安装 MuJoCo 模拟器。MuJoCo 是一款广泛用于物理模拟的引擎,用于强化学习任务中的复杂物理环境,比如四足机器人和其他连续控制任务。
解决方法:安装 MuJoCo
-
下载 MuJoCo: 前往 MuJoCo 官方网站,并下载适用于您的操作系统的 MuJoCo 2.1 版本。
-
安装 MuJoCo:
- 将下载的 MuJoCo 解压到指定目录,例如:
C:\Users\PC\.mujoco\mujoco210
。 - 确保在系统的环境变量中添加了指向该目录的路径。
- 将下载的 MuJoCo 解压到指定目录,例如:
-
安装
mujoco-py
: 您需要安装mujoco-py
作为 MuJoCo 的 Python 接口。pip install mujoco-py
-
安装完成后,设置许可证: 下载 MuJoCo 时会提供一个
mjkey.txt
文件。将此文件复制到 MuJoCo 的安装目录下,例如:C:\Users\PC\.mujoco\mjkey.txt
。
升级或降级 Gym 版本
根据警告提示,您可以通过升级到 gym==0.25.1
或降级到 gym==0.23.1
来解决与 gym.make()
相关的问题。推荐的步骤是升级 gym
到较新的版本:
pip install gym==0.25.1
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。✨
更多推荐
所有评论(0)