td0算法(深度强化学习-TD3算法原理与代码)
深度强化学习-TD3算法原理与代码
引言
1 TD3算法简介
2 TD3算法原理
2.1 双重网络
2.1.1 网络过估计的成因
2.1.2 双重网络的引入
2.2 目标策略平滑正则化
2.3 延迟更新
3 TD3算法更新过程
4 TD3算法伪代码
5 PyTorch代码实现
6 实验结果
7 结论
引言
Twin Delayed Deep Deterministic policy gradient (TD3)是由Scott Fujimoto等人在Deep Deterministic Policy Gradient (DDPG)算法上改进得到的一种用于解决连续控制问题的在线(on-line)异策(off-policy)式深度强化学习算法 。本质上 ,TD3算法就是将Double Q-Learning算法的思想融入到DDPG算法中 。前面我们已经分别介绍过DDPG算法和Double DQN算法的原理并进行了代码实现 ,有兴趣的小伙伴可以先去看一下,之后再来看本文应该就能很容易理解 。本文就带领大家了解一下TD3算法的具体原理 ,并采用Pytorch进行实现 ,论文和代码的链接见下方 。
论文:http://proceedings.mlr.press/v80/fujimoto18a/fujimoto18a.pdf
代码:https://github.com/indigoLovee/TD3
1 TD3算法简介
之前我们在讲Double DQN算法时就曾分析过Deep Q-Learning (DQN)算法存在高估问题 ,而DDPG算法是从DQN算法进化得到 ,因此它也存在一样的问题 。为此 ,TD3算法就很自然地被提出 ,主要解决DDPG算法的高估问题 。
TD3算法也是Actor-Critic (AC)框架下的一种确定性深度强化学习算法 ,它结合了深度确定性策略梯度算法和双重Q学习 ,在许多连续控制任务上都取得了不错的表现 。
2 TD3算法原理
TD3算法在DDPG算法的基础上 ,提出了三个关键技术:
(1)双重网络 (Double network):采用两套Critic网络,计算目标值时取二者中的较小值 ,从而抑制网络过估计问题 。
(2)目标策略平滑正则化 (Target policy smoothing regularization):计算目标值时 ,在下一个状态的动作上加入扰动,从而使得价值评估更准确 。
(3)延迟更新 (Delayed update):Critic网络更新多次后 ,再更新Actor网络 ,从而保证Actor网络的训练更加稳定。
2.1 双重网络
TD3算法中包括六个网络,分别是Actor网络 ,Critic1网络 ,Critic2网络 ,Target Actor网络 ,Target Critic1网络 ,Target Critic2网络 。相较于DDPG算法 ,TD3算法多了一套Critic网络 ,这就是双重网络的由来 。本节首先分析网络过估计的成因 ,然后引入双重网络 ,最后介绍算法的更新过程。
2.1.1 网络过估计的成因
DQN算法的高估主要来源于两个方面:自举 (Bootstrapping)和最大化,DDPG算法也是如此 。之前我们在讲DDPG算法时曾强调 ,如果高估是均匀的 ,对于智能体最终的决策不会带来影响;如果是非均匀的,对于智能体最终的决策会带来显著影响 。然而实际上网络的高估通常是非均匀的 ,这里简单分析一下原因。
在更新Critic网络时 ,假设从经验池中采样的数据为 。首先我们会计算目标
由于网络高估,因此
其中 , 表示状态动作对真实的最优状态动作价值 。
接着我们会让逼近 ,从而使得出现过估计 ,即
其中 , 表示状态动作对真实的最优状态动作价值 。
每次采样状态动作对来对Critic网络进行更新时 ,就会让网络高估的状态动作价值 ,而在经验池中的频率显然是不均匀的 。如果出现的频率越高 ,那么高估就越严重 。因此 ,网络的高估是非均匀的 ,而非均匀的高估对智能体的决策有害,因此我们需要避免网络高估 。
2.1.2 双重网络的引入
DDPG算法采用目标网络解决了自举问题 ,有兴趣的小伙伴可以看一下我的那篇博文 ,里面详细分析了自举问题带来的危害,以及目标网络是如何解决自举问题的 ,这里就不再赘述了 。但是 ,除了自举以外,最大化也是造成过估计的重要原因 ,因此要想彻底解决网络过估计 ,还需要解决最大化问题 。这里简单分析一下为什么最大化会造成网络过估计 。
假设为观测到的真实值 ,在其中加入均值为0的随机噪声 ,得到。由于噪声的均值为0 ,因此满足
但是随机噪声会让最大值变大 ,即
同理 , 随机噪声会让最小值变小 ,即
这三个公式都可以被证明出来 ,这里就不给大家证明了 。
回到DQN算法的高估问题,假设每个状态动作对的真实状态动作价值为 。Q网络的估计会存在一定噪声 ,不妨假设是无偏估计 ,那么估计出的状态动作价值为。由于噪声的均值为0,因此满足
是一种典型的高估 ,即
我们在计算目标值时 ,会执行,由于高估 ,因此目标值
也会高估 。网络更新时我们会将逼近 ,由于高估 ,因此就会出现高估 。
总结起来就是 ,最大化操作会使得网络的估计值大于真实值 ,从而造成网络过估计。
双重网络是解决最大化问题的有效方法 。在TD3算法中 ,作者引入了两套相同网络架构的Critic网络 。计算目标值时 ,会利用二者间的较小值来估计下一个状态动作对的状态动作价值 ,即
从而可以有效避免最大化问题带来的高估 。这时可能会有小伙伴比较疑惑 ,取两个网络之间的较小值会不会不太稳妥?如果用多个网络,然后取它们中的最小值会不会更好呢?其实有实验证明采用两个网络就可以了 ,多个网络不会带来明显的性能提升 。
2.2 目标策略平滑正则化
确定性策略存在一个问题:它会过度拟合以缩小价值估计中的峰值 。当更新Critic网络时 ,使用确定性策略的学习目标极易受到函数逼近误差的影响,从而导致目标估计的方差大 ,估计值不准确 。这种诱导方差可以通过正则化来减少 ,因此作者模仿SARSA的学习更新,引入了一种深度价值学习的正则化策略——目标策略平滑 。
这种方法主要强调:类似的行动应该具有类似的价值 。虽然函数近似隐式地实现了这一点 ,但可以通过修改训练过程显示地强调类似动作之间的关系 。具体的实现是利用目标动作周围的区域来计算目标值 ,从而有利于平滑估计值
在实际操作时 ,我们可以通过向目标动作中添加少量随机噪声 ,并在小批量中求平均值 ,来近似动作的期望。因此 ,上式可以修改为
其中 ,我们添加的噪声是服从正态分布的 ,并且对采样的噪声做了裁剪 ,以保持目标接近原始动作 。直观的说,采用这种方法得出的策略往往更加安全 ,因为它们为抵抗干扰的动作提供了更高的价值 。说了这么多可能不是特别容易理解 ,不妨来看两张图。
假设上图为Critic网络估计的Q值曲面 。这里我们直接采用来估计,因此方差会很大 ,不利于网络训练 。
这次我们采用状态动作对的邻域来估计 ,从而可以极大地降低方差,提高目标值估计的准确性 ,保证网络训练过程的稳定。
2.3 延迟更新
这里的延迟更新指的是Actor网络的延迟更新 ,即Critic网络更新多次之后再对Actor网络进行更新 。这个想法其实是非常直观的 ,因为Actor网络是通过最大化累积期望回报来更新的 ,它需要利用Critic网络来进行评估 。如果Critic网络非常不稳定 ,那么Actor网络自然也会出现震荡 。
因此 ,我们可以让Critic网络的更新频率高于Actor网络 ,即等待Critic网络更加稳定之后再来帮助Actor网络更新 。
3 TD3算法更新过程
TD3算法的更新过程与DDPG算法的更新过程差别不大 ,主要区别在于目标值的计算方式(2.1.2节已经给出) 。其中Actor网络通过最大化累积期望回报来更新(确定性策略梯度) ,Critic1和Critic2网络都是通过最小化评估值与目标值之间的误差来更新(MSE),所有的目标网络都采用软更新的方式来更新(Exponential Moving Average, EMA) 。在训练阶段 ,我们从Replay Buffer中采样一个批次 (Batch size) 的数据 ,假设采样到的一条数据为,所有网络的更新过程如下 。
Critic1和Critic2网络更新过程:利用Target Actor网络计算出状态下的动作
然后基于目标策略平滑正则化 ,再目标动作上加入噪声
接着基于双重网络的思想 ,计算目标值
最后利用梯度下降算法最小化评估值和目标值之间的误差,从而对Critic1和Critic2网络中的参数进行更新
Actor网络更新过程:(在Ctitic1和Critic2网络更新 步之后 ,启动Actor网络更新) 利用Actor网络计算出状态下的动作
这里需要注意:计算出动作后不需要加入噪声 ,因为这里是希望Actor网络能够朝着最大值方向更新 ,加入噪声没有任何意义 。然后利用Critic1或者Critic2网络来计算状态动作对的评估值 ,这里我们假定使用Critic1网络
最后采用梯度上升算法最大化 ,从而完成对Actor网络的更新 。
注:这里我们之所以可以使用Critic1和Critic2两者中的任何一个来计算Q值 ,我觉得主要是因为Actor网络的目的就在于最大化累积期望回报 ,没有必要使用最小值。
目标网络的更新过程:采用软更新方式对目标网络进行更新 。引入一个学习率(或者成为动量) ,将旧的目标网络参数和新的对应网络参数做加权平均 ,然后赋值给目标网络
学习率(动量),通常取值0.005 。
4 TD3算法伪代码
5 PyTorch代码实现
Replay Buffer的代码实现(buffer.py):
import numpy as np class ReplayBuffer: def __init__(self, max_size, state_dim, action_dim, batch_size): self.mem_size = max_size self.batch_size = batch_size self.mem_cnt = 0 self.state_memory = np.zeros((max_size, state_dim)) self.action_memory = np.zeros((max_size, action_dim)) self.reward_memory = np.zeros((max_size, )) self.next_state_memory = np.zeros((max_size, state_dim)) self.terminal_memory = np.zeros((max_size, ), dtype=np.bool) def store_transition(self, state, action, reward, state_, done): mem_idx = self.mem_cnt % self.mem_size self.state_memory[mem_idx] = state self.action_memory[mem_idx] = action self.reward_memory[mem_idx] = reward self.next_state_memory[mem_idx] = state_ self.terminal_memory[mem_idx] = done self.mem_cnt += 1 def sample_buffer(self): mem_len = min(self.mem_cnt, self.mem_size) batch = np.random.choice(mem_len, self.batch_size, replace=False) states = self.state_memory[batch] actions = self.action_memory[batch] rewards = self.reward_memory[batch] states_ = self.next_state_memory[batch] terminals = self.terminal_memory[batch] return states, actions, rewards, states_, terminals def ready(self): return self.mem_cnt >= self.batch_sizeActor和Critic网络的代码实现(networks.py):
import torch as T import torch.nn as nn import torch.optim as optim device = T.device("cuda:0" if T.cuda.is_available() else "cpu") class ActorNetwork(nn.Module): def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim): super(ActorNetwork, self).__init__() self.fc1 = nn.Linear(state_dim, fc1_dim) self.ln1 = nn.LayerNorm(fc1_dim) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.ln2 = nn.LayerNorm(fc2_dim) self.action = nn.Linear(fc2_dim, action_dim) self.optimizer = optim.Adam(self.parameters(), lr=alpha) self.to(device) def forward(self, state): x = T.relu(self.ln1(self.fc1(state))) x = T.relu(self.ln2(self.fc2(x))) action = T.tanh(self.action(x)) return action def save_checkpoint(self, checkpoint_file): T.save(self.state_dict(), checkpoint_file, _use_new_zipfile_serialization=False) def load_checkpoint(self, checkpoint_file): self.load_state_dict(T.load(checkpoint_file)) class CriticNetwork(nn.Module): def __init__(self, beta, state_dim, action_dim, fc1_dim, fc2_dim): super(CriticNetwork, self).__init__() self.fc1 = nn.Linear(state_dim+action_dim, fc1_dim) self.ln1 = nn.LayerNorm(fc1_dim) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.ln2 = nn.LayerNorm(fc2_dim) self.q = nn.Linear(fc2_dim, 1) self.optimizer = optim.Adam(self.parameters(), lr=beta) self.to(device) def forward(self, state, action): x = T.cat([state, action], dim=-1) x = T.relu(self.ln1(self.fc1(x))) x = T.relu(self.ln2(self.fc2(x))) q = self.q(x) return q def save_checkpoint(self, checkpoint_file): T.save(self.state_dict(), checkpoint_file, _use_new_zipfile_serialization=False) def load_checkpoint(self, checkpoint_file): self.load_state_dict(T.load(checkpoint_file))TD3算法的代码实现(TD3.py):
import torch as T import torch.nn.functional as F import numpy as np from networks import ActorNetwork, CriticNetwork from buffer import ReplayBuffer device = T.device("cuda:0" if T.cuda.is_available() else "cpu") class TD3: def __init__(self, alpha, beta, state_dim, action_dim, actor_fc1_dim, actor_fc2_dim, critic_fc1_dim, critic_fc2_dim, ckpt_dir, gamma=0.99, tau=0.005, action_noise=0.1, policy_noise=0.2, policy_noise_clip=0.5, delay_time=2, max_size=1000000, batch_size=256): self.gamma = gamma self.tau = tau self.action_noise = action_noise self.policy_noise = policy_noise self.policy_noise_clip = policy_noise_clip self.delay_time = delay_time self.update_time = 0 self.checkpoint_dir = ckpt_dir self.actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim) self.critic1 = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim, fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim) self.critic2 = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim, fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim) self.target_actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim) self.target_critic1 = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim, fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim) self.target_critic2 = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim, fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim) self.memory = ReplayBuffer(max_size=max_size, state_dim=state_dim, action_dim=action_dim, batch_size=batch_size) self.update_network_parameters(tau=1.0) def update_network_parameters(self, tau=None): if tau is None: tau = self.tau for actor_params, target_actor_params in zip(self.actor.parameters(), self.target_actor.parameters()): target_actor_params.data.copy_(tau * actor_params + (1 - tau) * target_actor_params) for critic1_params, target_critic1_params in zip(self.critic1.parameters(), self.target_critic1.parameters()): target_critic1_params.data.copy_(tau * critic1_params + (1 - tau) * target_critic1_params) for critic2_params, target_critic2_params in zip(self.critic2.parameters(), self.target_critic2.parameters()): target_critic2_params.data.copy_(tau * critic2_params + (1 - tau) * target_critic2_params) def remember(self, state, action, reward, state_, done): self.memory.store_transition(state, action, reward, state_, done) def choose_action(self, observation, train=True): self.actor.eval() state = T.tensor([observation], dtype=T.float).to(device) action = self.actor.forward(state) if train: # exploration noise noise = T.tensor(np.random.normal(loc=0.0, scale=self.action_noise), dtype=T.float).to(device) action = T.clamp(action+noise, -1, 1) self.actor.train() return action.squeeze().detach().cpu().numpy() def learn(self): if not self.memory.ready(): return states, actions, rewards, states_, terminals = self.memory.sample_buffer() states_tensor = T.tensor(states, dtype=T.float).to(device) actions_tensor = T.tensor(actions, dtype=T.float).to(device) rewards_tensor = T.tensor(rewards, dtype=T.float).to(device) next_states_tensor = T.tensor(states_, dtype=T.float).to(device) terminals_tensor = T.tensor(terminals).to(device) with T.no_grad(): next_actions_tensor = self.target_actor.forward(next_states_tensor) action_noise = T.tensor(np.random.normal(loc=0.0, scale=self.policy_noise), dtype=T.float).to(device) # smooth noise action_noise = T.clamp(action_noise, -self.policy_noise_clip, self.policy_noise_clip) next_actions_tensor = T.clamp(next_actions_tensor+action_noise, -1, 1) q1_ = self.target_critic1.forward(next_states_tensor, next_actions_tensor).view(-1) q2_ = self.target_critic2.forward(next_states_tensor, next_actions_tensor).view(-1) q1_[terminals_tensor] = 0.0 q2_[terminals_tensor] = 0.0 critic_val = T.min(q1_, q2_) target = rewards_tensor + self.gamma * critic_val q1 = self.critic1.forward(states_tensor, actions_tensor).view(-1) q2 = self.critic2.forward(states_tensor, actions_tensor).view(-1) critic1_loss = F.mse_loss(q1, target.detach()) critic2_loss = F.mse_loss(q2, target.detach()) critic_loss = critic1_loss + critic2_loss self.critic1.optimizer.zero_grad() self.critic2.optimizer.zero_grad() critic_loss.backward() self.critic1.optimizer.step() self.critic2.optimizer.step() self.update_time += 1 if self.update_time % self.delay_time != 0: return new_actions_tensor = self.actor.forward(states_tensor) q1 = self.critic1.forward(states_tensor, new_actions_tensor) actor_loss = -T.mean(q1) self.actor.optimizer.zero_grad() actor_loss.backward() self.actor.optimizer.step() self.update_network_parameters() def save_models(self, episode): self.actor.save_checkpoint(self.checkpoint_dir + Actor/TD3_actor_{}.pth.format(episode)) print(Saving actor network successfully!) self.target_actor.save_checkpoint(self.checkpoint_dir + Target_actor/TD3_target_actor_{}.pth.format(episode)) print(Saving target_actor network successfully!) self.critic1.save_checkpoint(self.checkpoint_dir + Critic1/TD3_critic1_{}.pth.format(episode)) print(Saving critic1 network successfully!) self.target_critic1.save_checkpoint(self.checkpoint_dir + Target_critic1/TD3_target_critic1_{}.pth.format(episode)) print(Saving target critic1 network successfully!) self.critic2.save_checkpoint(self.checkpoint_dir + Critic2/TD3_critic2_{}.pth.format(episode)) print(Saving critic2 network successfully!) self.target_critic2.save_checkpoint(self.checkpoint_dir + Target_critic2/TD3_target_critic2_{}.pth.format(episode)) print(Saving target critic2 network successfully!) def load_models(self, episode): self.actor.load_checkpoint(self.checkpoint_dir + Actor/TD3_actor_{}.pth.format(episode)) print(Loading actor network successfully!) self.target_actor.load_checkpoint(self.checkpoint_dir + Target_actor/TD3_target_actor_{}.pth.format(episode)) print(Loading target_actor network successfully!) self.critic1.load_checkpoint(self.checkpoint_dir + Critic1/TD3_critic1_{}.pth.format(episode)) print(Loading critic1 network successfully!) self.target_critic1.load_checkpoint(self.checkpoint_dir + Target_critic1/TD3_target_critic1_{}.pth.format(episode)) print(Loading target critic1 network successfully!) self.critic2.load_checkpoint(self.checkpoint_dir + Critic2/TD3_critic2_{}.pth.format(episode)) print(Loading critic2 network successfully!) self.target_critic2.load_checkpoint(self.checkpoint_dir + Target_critic2/TD3_target_critic2_{}.pth.format(episode)) print(Loading target critic2 network successfully!)算法仿真环境是gym库中的LunarLanderContinuous-v2环境 ,因此需要先配置好gym库。进入Aanconda中对应的Python环境中 ,执行下面的指令
pip install gym但是,这样安装的gym库只包括少量的内置环境 ,如算法环境 、简单文字游戏环境和经典控制环境 ,无法使用LunarLanderContinuous-v2 。因此还要安装一些其他依赖项,具体可以参照这篇blog: AttributeError: module ‘gym.envs.box2d‘ has no attribute ‘LunarLander‘解决办法 。如果已经配置好环境 ,那请忽略这一段。
训练脚本(train.py):
import gym import numpy as np import argparse from TD3 import TD3 from utils import create_directory, plot_learning_curve, scale_action parser = argparse.ArgumentParser() parser.add_argument(--max_episodes, type=int, default=1000) parser.add_argument(--ckpt_dir, type=str, default=./checkpoints/TD3/) parser.add_argument(--figure_file, type=str, default=./output_images/reward.png) args = parser.parse_args() def main(): env = gym.make(LunarLanderContinuous-v2) agent = TD3(alpha=0.0003, beta=0.0003, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], actor_fc1_dim=400, actor_fc2_dim=300, critic_fc1_dim=400, critic_fc2_dim=300, ckpt_dir=args.ckpt_dir, gamma=0.99, tau=0.005, action_noise=0.1, policy_noise=0.2, policy_noise_clip=0.5, delay_time=2, max_size=1000000, batch_size=256) create_directory(path=args.ckpt_dir, sub_path_list=[Actor, Critic1, Critic2, Target_actor, Target_critic1, Target_critic2]) total_reward_history = [] avg_reward_history = [] for episode in range(args.max_episodes): total_reward = 0 done = False observation = env.reset() while not done: action = agent.choose_action(observation, train=True) action_ = scale_action(action, low=env.action_space.low, high=env.action_space.high) observation_, reward, done, info = env.step(action_) agent.remember(observation, action, reward, observation_, done) agent.learn() total_reward += reward observation = observation_ total_reward_history.append(total_reward) avg_reward = np.mean(total_reward_history[-100:]) avg_reward_history.append(avg_reward) print(Ep: {} Reward: {} AvgReward: {}.format(episode+1, total_reward, avg_reward)) if (episode + 1) % 200 == 0: agent.save_models(episode+1) episodes = [i+1 for i in range(args.max_episodes)] plot_learning_curve(episodes, avg_reward_history, title=AvgReward, ylabel=reward, figure_file=args.figure_file) if __name__ == __main__: main()训练脚本中有三个参数 ,max_episodes表示训练幕数 ,checkpoint_dir表示训练权重保存路径 ,figure_file表示训练结果的保存路径(其实是一张累积奖励曲线图) ,按照默认设置即可 。
训练时还会用到画图函数和创建文件夹函数 ,它们都被放在utils.py脚本中:
import os import numpy as np import matplotlib.pyplot as plt def create_directory(path: str, sub_path_list: list): for sub_path in sub_path_list: if not os.path.exists(path + sub_path): os.makedirs(path + sub_path, exist_ok=True) print(Path: {} create successfully!.format(path + sub_path)) else: print(Path: {} is already existence!.format(path + sub_path)) def plot_learning_curve(episodes, records, title, ylabel, figure_file): plt.figure() plt.plot(episodes, records, color=b, linestyle=-) plt.title(title) plt.xlabel(episode) plt.ylabel(ylabel) plt.show() plt.savefig(figure_file) def scale_action(action, low, high): action = np.clip(action, -1, 1) weight = (high - low) / 2 bias = (high + low) / 2 action_ = action * weight + bias return action_另外我们还提供了测试代码 ,主要用于测试训练效果以及观察环境的动态渲染 (test.py):
import gym import imageio import argparse from TD3 import TD3 from utils import scale_action parser = argparse.ArgumentParser() parser.add_argument(--ckpt_dir, type=str, default=./checkpoints/TD3/) parser.add_argument(--figure_file, type=str, default=./output_images/LunarLander.gif) parser.add_argument(--fps, type=int, default=30) parser.add_argument(--render, type=bool, default=True) parser.add_argument(--save_video, type=bool, default=True) args = parser.parse_args() def main(): env = gym.make(LunarLanderContinuous-v2) agent = TD3(alpha=0.0003, beta=0.0003, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0], actor_fc1_dim=400, actor_fc2_dim=300, critic_fc1_dim=400, critic_fc2_dim=300, ckpt_dir=args.ckpt_dir, gamma=0.99, tau=0.005, action_noise=0.1, policy_noise=0.2, policy_noise_clip=0.5, delay_time=2, max_size=1000000, batch_size=256) agent.load_models(1000) video = imageio.get_writer(args.figure_file, fps=args.fps) done = False observation = env.reset() while not done: if args.render: env.render() action = agent.choose_action(observation, train=True) action_ = scale_action(action, low=env.action_space.low, high=env.action_space.high) observation_, reward, done, info = env.step(action_) observation = observation_ if args.save_video: video.append_data(env.render(mode=rgb_array)) if __name__ == __main__: main()测试脚本中包括五个参数 ,filename表示环境动态图的保存路径 ,checkpoint_dir表示加载的权重路径,save_video表示是否要保存动态图 ,fps表示动态图的帧率 ,rander表示是否开启环境渲染 。大家只需要调整save_video和rander这两个参数,其余保持默认即可 。
6 实验结果
通过平均奖励曲线可以看出 ,大概迭代到400步左右时算法趋于收敛 。相较于DDPG算法 ,TD3算法的性能有了明显提升 。
这是测试效果图,智能体能够很好地完成降落任务 ,整个过程非常平稳!
7 结论
本文主要讲解了TD3算法中的相关技术细节 ,并进行了代码实现 。TD3算法是一种有效解决连续控制问题的深度强化学习算法 ,也是我非常喜欢用的算法之一 ,希望各位小伙伴能够理解并掌握这个算法 。
以上如果有出现错误的地方 ,欢迎各位怒斥!
创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!