```python
import random

import gym
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


class Qnet(torch.nn.Module):
    ''' Q-network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class DQN:
    ''' DQN algorithm '''
    def __init__(self, learning_rate=2e-3, gamma=0.98, epsilon=0.01,
                 target_update_step=10, ReplayBuffer_size=10000,
                 device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")):
        self.train_env = gym.make('CartPole-v1', new_step_api=True)
        self.eval_env = gym.make('CartPole-v1', new_step_api=True)
        self.env_seed = 0
        self.replay_buffer = ReplayBuffer(ReplayBuffer_size)
        self.state_dim = self.train_env.observation_space.shape[0]
        self.action_dim = self.train_env.action_space.n
        self.q_net = Qnet(self.state_dim, 128, self.action_dim).to(device)
        self.target_q_net = Qnet(self.state_dim, 128, self.action_dim).to(device)
        self.target_update_step = target_update_step
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.device = device

    def set_env_seed(self, seed):
        self.env_seed = seed

    def take_action(self, state):
        # epsilon-greedy action selection
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def evaluate_model(self, num_episodes=10):
        # greedy rollouts with the current q_net; returns the average episode reward
        total_reward = 0
        for i_episode in range(num_episodes):
            state = self.eval_env.reset(seed=self.env_seed)
            done = False
            while not done:
                state_tensor = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
                q_values = self.q_net(state_tensor)
                action = np.argmax(q_values.detach().cpu().numpy())
                next_state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated  # also stop at the 500-step truncation, otherwise a good policy never exits
                state = next_state
                total_reward += reward
        return total_reward / num_episodes

    def update_q_net(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)
        # TD target computed with the frozen target network
        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

    def run(self, num_episodes=500, batch_size=64):
        max_return = 0
        q_net_update_times = 0
        episode_return_list = []
        pbar = tqdm(total=num_episodes, desc="Training")
        for i_episode in range(num_episodes):
            episode_return = 0
            state = self.train_env.reset(seed=self.env_seed)
            done = False
            while not done:
                action = self.take_action(state)
                next_state, reward, terminated, truncated, _ = self.train_env.step(action)
                done = terminated or truncated  # end the episode on truncation as well
                episode_return += reward
                # store only true termination so bootstrapping is not cut off at the time limit
                self.replay_buffer.add(state, action, reward, next_state, terminated)
                state = next_state
                if self.replay_buffer.size() > 500:
                    b_s, b_a, b_r, b_ns, b_d = self.replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns,
                                       'rewards': b_r, 'dones': b_d}
                    self.update_q_net(transition_dict)
                    if q_net_update_times % self.target_update_step == 0:
                        self.target_q_net.load_state_dict(self.q_net.state_dict())
                    q_net_update_times += 1
            episode_return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({"Episode": i_episode + 1,
                                  "Return": f"{np.mean(episode_return_list[-10:]):.3f}"})
                # keep the best-performing weights seen so far
                current_return = self.evaluate_model(10)
                if current_return > max_return:
                    max_return = current_return
                    torch.save(self.q_net.state_dict(), 'q_net.pth')
            pbar.update(1)
        pbar.close()

    def show_result(self):
        env_name = 'CartPole-v1'
        env = gym.make(env_name, new_step_api=True)
        self.q_net.load_state_dict(torch.load('q_net.pth'))
        self.q_net.eval()
        frames = []
        total_reward = 0
        for i_episode in range(10):
            state = env.reset(seed=self.env_seed)
            done = False
            while not done:
                state_tensor = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
                q_values = self.q_net(state_tensor)
                action = np.argmax(q_values.detach().cpu().numpy())
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                state = next_state
                total_reward += reward
                frame = env.render(mode='rgb_array')
                frames.append(frame)
        print(f"Average reward: {total_reward/10:.3f}")
        imageio.mimsave('cartpole.gif', frames, duration=0.1)


if __name__ == '__main__':
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    agent = DQN(learning_rate=2e-3, gamma=0.98, epsilon=0.01,
                target_update_step=10, ReplayBuffer_size=10000)
    agent.set_env_seed(0)
    agent.run(num_episodes=500, batch_size=64)
    agent.show_result()
```
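The `DQN` class above calls a `ReplayBuffer` with `add`, `size`, and `sample` methods that is not included in this snippet. For completeness, a minimal sketch of a compatible buffer (the class name and method signatures come from the calls above; the internals are only an illustration, not necessarily the original implementation):

```python
import collections
import random

import numpy as np


class ReplayBuffer:
    ''' Fixed-capacity FIFO experience replay buffer. '''
    def __init__(self, capacity):
        # deque with maxlen evicts the oldest transition automatically once full
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        # stack states into arrays so the torch.tensor(...) calls in update_q_net stay cheap
        return np.array(states), actions, rewards, np.array(next_states), dones

    def size(self):
        return len(self.buffer)
```

Sampling uniformly at random from the buffer breaks the temporal correlation between consecutive transitions, which is the main reason DQN uses experience replay at all.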
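For reference, `update_q_net` minimizes the mean squared TD error between the online network's value and the standard DQN target computed with the frozen target network:

$$
y_i = r_i + \gamma \,(1 - d_i)\, \max_{a'} Q_{\theta^-}(s'_i, a'), \qquad
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \bigl(Q_\theta(s_i, a_i) - y_i\bigr)^2
$$

where $d_i$ is the stored done flag and $\theta^-$ are the target-network weights, copied from the online network every `target_update_step` gradient steps in the `run` loop.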