網站首頁 編程語言 正文
0. 前言
本節中,我們使用策略梯度算法解決?CartPole
?問題。雖然在這個簡單問題中,使用隨機搜索策略和爬山算法就足夠了。但是,我們可以使用這個簡單問題來更專注的學習策略梯度算法,并在之后的學習中使用此算法解決更加復雜的問題。
1. 策略梯度算法
策略梯度算法通過記錄回合中的所有時間步并基于回合結束時與這些時間步相關聯的獎勵來更新權重訓練智能體。使智能體遍歷整個回合然后基于獲得的獎勵更新策略的技術稱為蒙特卡洛策略梯度。
在策略梯度算法中,模型權重在每個回合結束時沿梯度方向移動。關于梯度的計算,我們將在下一節中詳細解釋。此外,在每一時間步中,基于當前狀態和權重計算的概率得到策略,并從中采樣一個動作。與隨機搜索和爬山算法(通過采取確定性動作以獲得更高的得分)相反,它不再確定地采取動作。因此,策略從確定性轉變為隨機性。例如,如果向左的動作和向右的動作的概率為?[0.8,0.2]
,則表示有?80%
?的概率選擇向左的動作,但這并不意味著一定會選擇向左的動作。
2. 使用策略梯度算法解決CartPole問題
在本節中,我們將學習使用?PyTorch
?實現策略梯度算法了。 導入所需的庫,創建?CartPole
?環境實例,并計算狀態空間和動作空間的尺寸:
import gym import torch import matplotlib.pyplot as plt env = gym.make('CartPole-v0') n_state = env.observation_space.shape[0] print(n_state) n_action = env.action_space.n print(n_action)
定義?run_episode
?函數,在此函數中,根據給定輸入權重的情況下模擬一回合?CartPole
?游戲,并返回獎勵和計算出的梯度。在每個時間步中執行以下操作:
- 根據當前狀態和輸入權重計算兩個動作的概率?
probs
- 根據結果概率采樣一個動作?
action
- 以概率作為輸入計算?
softmax
?函數的導數?d_softmax
,由于只需要計算與選定動作相關的導數,因此:
\frac {\partial p_i} {\partial z_j} = p_i(1-p_j), i=j?zj??pi??=pi?(1?pj?),i=j
- 將所得的導數?
d_softmax
?除以概率?probs
,以得與策略相關的對數導數?d_log
- 根據鏈式法則計算權重的梯度?
grad
:
\frac {dy}{dx}=\frac{dy}{du}\cdot\frac{du}{dx}dxdy?=dudy??dxdu?
- 記錄得到的梯度?
grad
- 執行動作,累積獎勵并更新狀態
def run_episode(env, weight): state = env.reset() grads = [] total_reward = 0 is_done = False while not is_done: state = torch.from_numpy(state).float() # 根據當前狀態和輸入權重計算兩個動作的概率 probs z = torch.matmul(state, weight) probs = torch.nn.Softmax(dim=0)(z) # 根據結果概率采樣一個動作 action action = int(torch.bernoulli(probs[1]).item()) # 以概率作為輸入計算 softmax 函數的導數 d_softmax d_softmax = torch.diag(probs) - probs.view(-1, 1) * probs # 計算與策略相關的對數導數d_log d_log = d_softmax[action] / probs[action] # 計算權重的梯度grad grad = state.view(-1, 1) * d_log grads.append(grad) state, reward, is_done, _ = env.step(action) total_reward += reward if is_done: break return total_reward, grads
回合完成后,返回在此回合中獲得的總獎勵以及在各個時間步中計算的梯度信息,用于之后更新權重。
接下來,定義要運行的回合數,在每個回合中調用?run_episode
?函數,并初始化權重以及用于記錄每個回合總獎勵的變量:
n_episode = 1000 weight = torch.rand(n_state, n_action) total_rewards = []
在每個回合結束后,使用計算出的梯度來更新權重。對于回合中的每個時間步,權重都根據學習率、計算出的梯度和智能體在剩余時間步中的獲得的總獎勵進行更新。
我們知道在回合終止之前,每一時間步的獎勵都是?1
。因此,我們用于計算每個時間步策略梯度的未來獎勵是剩余的時間步數。在每個回合之后,我們使用隨機梯度上升方法將梯度乘以未來獎勵來更新權重。這樣,一個回合中經歷的時間步越長,權重的更新幅度就越大,這將增加獲得更大總獎勵的機會。我們設定學習率為?0.001
:
learning_rate = 0.001 for e in range(n_episode): total_reward, gradients = run_episode(env, weight) print('Episode {}: {}'.format(e + 1, total_reward)) for i, gradient in enumerate(gradients): weight += learning_rate * gradient * (total_reward - i) total_rewards.append(total_reward)
然后,我們計算通過策略梯度算法獲得的平均總獎勵:
print('Average total reward over {} episode: {}'.format(n_episode, sum(total_rewards)/n_episode))
我們可以繪制每個回合的總獎勵變化情況,如下所示:
plt.plot(total_rewards) plt.xlabel('Episode') plt.ylabel('Reward') plt.show()
在上圖中,我們可以看到獎勵會隨著訓練回合的增加呈現出上升趨勢,然后能夠在最大值處穩定。我們還可以看到,即使在收斂之后,獎勵也會振蕩,這是由于策略梯度算法是一種隨機策略算法。
最后,我們查看學習到策略在?1000
?個新回合中的性能表現,并計算平均獎勵:
n_episode_eval = 1000 total_rewards_eval = [] for e in range(n_episode_eval): total_reward, _ = run_episode(env, weight) print('Episode {}: {}'.format(e+1, total_reward)) total_rewards_eval.append(total_reward) print('Average total reward over {} episode: {}'.format(n_episode_eval, sum(total_rewards_eval)/n_episode_eval)) # Average total reward over 1000 episode: 200
進行測試后,可以看到回合的平均獎勵接近最大值?200
。可以多次測試訓練后的模型,得到的平均獎勵較為穩定。正如我們一開始所說的那樣,對于諸如?CartPole
?之類的簡單環境,策略梯度算法可能大材小用,但它為我們解決更加復雜的問題奠定了基礎。
原文鏈接:https://juejin.cn/post/7118954918198640654
相關推薦
- 2023-08-16 uniapp中v-model數據無法讀取問題 failed for prop “value“
- 2022-10-27 kotlin?協程上下文異常處理詳解_Android
- 2023-01-29 React基于路由的代碼分割技術詳解_React
- 2022-12-27 python利用logging模塊實現根據日志級別打印不同顏色日志的代碼案例_python
- 2022-06-25 Python制作簡易計算器功能_python
- 2022-06-18 Redis官方可視化工具RedisInsight的安裝使用詳細教程(功能強大)_Redis
- 2022-10-14 Ubuntu18.04使用Xorg創建虛擬屏幕
- 2022-05-25 Entity?Framework?Core對Web項目生成數據庫表_實用技巧
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支