/src/TD3.py

https://github.com/huangwl18/modular-rl

# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
from __future__ import print_function
import torch
import torch.nn.functional as F
from ModularActor import ActorGraphPolicy
from ModularCritic import CriticGraphPolicy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TD3(object):

    def __init__(self, args):
        self.args = args
        # modular (graph-structured) actor/critic and their frozen target copies
        self.actor = ActorGraphPolicy(args.limb_obs_size, 1,
                                      args.msg_dim, args.batch_size,
                                      args.max_action, args.max_children,
                                      args.disable_fold, args.td, args.bu).to(device)
        self.actor_target = ActorGraphPolicy(args.limb_obs_size, 1,
                                             args.msg_dim, args.batch_size,
                                             args.max_action, args.max_children,
                                             args.disable_fold, args.td, args.bu).to(device)
        self.critic = CriticGraphPolicy(args.limb_obs_size, 1,
                                        args.msg_dim, args.batch_size,
                                        args.max_children, args.disable_fold,
                                        args.td, args.bu).to(device)
        self.critic_target = CriticGraphPolicy(args.limb_obs_size, 1,
                                               args.msg_dim, args.batch_size,
                                               args.max_children, args.disable_fold,
                                               args.td, args.bu).to(device)
        # target networks start as exact copies of the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.lr)

    def change_morphology(self, graph):
        # rebuild all four graph policies to match a new agent morphology
        self.actor.change_morphology(graph)
        self.actor_target.change_morphology(graph)
        self.critic.change_morphology(graph)
        self.critic_target.change_morphology(graph)

    def select_action(self, state):
        # deterministic action from the current policy (no exploration noise added here)
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        with torch.no_grad():
            action = self.actor(state, 'inference').cpu().numpy().flatten()
        return action
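
    # Illustrative rollout usage (a sketch, not from this file): `obs` is the
    # flattened per-limb observation for the current morphology, e.g.
    #     action = agent.select_action(obs)
    # where `agent` is a TD3 instance whose morphology has been set with
    # change_morphology() to match the environment being rolled out.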

    def train_single(self, replay_buffer, iterations, batch_size=100, discount=0.99,
                     tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        for it in range(iterations):
            # sample replay buffer
            x, y, u, r, d = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(x).to(device)
            next_state = torch.FloatTensor(y).to(device)
            action = torch.FloatTensor(u).to(device)
            reward = torch.FloatTensor(r).to(device)
            # d is 1 at terminal transitions, so (1 - d) zeroes out the
            # bootstrapped term at episode ends
            done = torch.FloatTensor(1 - d).to(device)

            # select action according to policy and add clipped noise
            with torch.no_grad():
                noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(device)
                noise = noise.clamp(-noise_clip, noise_clip)
                next_action = self.actor_target(next_state) + noise
                next_action = next_action.clamp(-self.args.max_action, self.args.max_action)

                # Qtarget = reward + discount * min_i(Qi(next_state, pi(next_state)))
                target_Q1, target_Q2 = self.critic_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + (done * discount * target_Q)

            # get current Q estimates
            current_Q1, current_Q2 = self.critic(state, action)

            # compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

            # optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # delayed policy updates
            if it % policy_freq == 0:

                # compute actor loss
                actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

                # optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # update the frozen target models
                for param, target_param in zip(self.critic.parameters(),
                                               self.critic_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(self.actor.parameters(),
                                               self.actor_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
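
    # Note: the defaults above (discount=0.99, tau=0.005, policy_noise=0.2,
    # noise_clip=0.5, policy_freq=2) match the standard TD3 hyperparameters
    # from Fujimoto et al., 2018.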

    def train(self, replay_buffer_list, iterations_list, batch_size=100, discount=0.99,
              tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2,
              graphs=None, envs_train_names=None):
        # split the total iteration budget evenly across the training morphologies
        # and run TD3 updates on each one in turn
        per_morph_iter = sum(iterations_list) // len(envs_train_names)
        for env_name in envs_train_names:
            replay_buffer = replay_buffer_list[env_name]
            self.change_morphology(graphs[env_name])
            self.train_single(replay_buffer, per_morph_iter, batch_size=batch_size, discount=discount,
                              tau=tau, policy_noise=policy_noise, noise_clip=noise_clip,
                              policy_freq=policy_freq)

    def save(self, fname):
        torch.save(self.actor.state_dict(), '%s_actor.pth' % fname)
        torch.save(self.critic.state_dict(), '%s_critic.pth' % fname)

    def load(self, fname):
        self.actor.load_state_dict(torch.load('%s_actor.pth' % fname))
        self.critic.load_state_dict(torch.load('%s_critic.pth' % fname))
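
The sketch below shows one way the class above might be wired together from another script; it is an illustrative assumption, not code from this repository. The argument names mirror the fields read in __init__ and train(), but the concrete values, the toy graph, and the random observation are placeholders; real morphology graphs and replay buffers come from the rest of the repo (a buffer only needs a sample(batch_size) method returning (x, y, u, r, d) as consumed in train_single).

# Hypothetical usage sketch -- not part of TD3.py; all values are illustrative.
import numpy as np
from types import SimpleNamespace
from TD3 import TD3

args = SimpleNamespace(
    limb_obs_size=19,   # per-limb observation size (illustrative)
    msg_dim=32,         # message dimension passed between limb modules
    batch_size=100,
    max_action=1.0,
    max_children=2,
    disable_fold=True,
    td=True,            # message-passing direction flags used by the policies
    bu=True,
    lr=3e-4,
)

agent = TD3(args)

# toy 3-limb morphology graph (placeholder); real graphs are built by the
# repo's utility code
graph = [-1, 0, 0]
agent.change_morphology(graph)

# one flattened observation for 3 limbs (placeholder data)
obs = np.random.randn(3 * args.limb_obs_size).astype(np.float32)
action = agent.select_action(obs)

# training expects per-environment replay buffers and graphs, e.g.
# agent.train(replay_buffer_list, iterations_list, batch_size=args.batch_size,
#             graphs=graphs, envs_train_names=envs_train_names)

agent.save('model')   # writes model_actor.pth and model_critic.pth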