RL Lab 2: Value Iteration
Solving a grid maze with value iteration.
Introduction
Dynamic programming is well suited to optimal decision problems in which the state-transition dynamics are fully known. The value iteration algorithm updates the state-value function directly from the Bellman optimality equation; the action-value function can then be derived from the state-value function, which in turn yields the policy.
$$ Q^*(s, a) = R(s, a) + \gamma \sum_{s'} P(s'|s, a)V^*(s') $$
or, equivalently, with per-transition rewards $R(s, a, s')$,
$$ Q^*(s, a) = \sum_{s'} P(s'|s, a)(\gamma V^*(s') + R(s, a, s')) $$
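As a minimal sketch of how these equations become an algorithm (assuming a generic tabular MDP described by dense tensors `P[s, a, s']` and `R[s, a]`, not the GridEnvironment API used below), value iteration repeatedly applies $V(s) \leftarrow \max_a Q(s, a)$ until the largest change in any state value falls below a tolerance:

import torch

def value_iteration(P, R, gamma=0.9, eps=1e-6):
    """Tabular value iteration on a generic MDP.

    P: (num_states, num_actions, num_states) tensor of transition probabilities P(s'|s, a)
    R: (num_states, num_actions) tensor of expected immediate rewards R(s, a)
    """
    num_states, num_actions, _ = P.shape
    V = torch.zeros(num_states)
    while True:
        # Q(s, a) = R(s, a) + gamma * sum_{s'} P(s'|s, a) V(s')
        Q = R + gamma * torch.einsum('san,n->sa', P, V)
        new_V = Q.max(dim=1).values           # V(s) = max_a Q(s, a)
        if (new_V - V).abs().max() < eps:     # stop when the largest update is tiny
            return new_V, Q.argmax(dim=1)     # optimal values and a greedy policy
        V = new_V

The ValueIterationAgent implemented later in this notebook performs the same update, vectorized over the environment's transition representation.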
Goals
- Implement the value iteration algorithm on top of GridEnvironment.
- Visualize the algorithm's convergence process, the state-value function, and the resulting policy.
Extensions
- How does the discount factor $\gamma$ affect the convergence of the algorithm? (A sketch of one way to investigate this is given at the end of the notebook.)
In [1]:
import torch
import random
from matplotlib import pyplot as plt
from rl_env import GridEnvironment
%config InlineBackend.figure_format = 'svg'
# seed = random.randint(0, 10000)
seed = 7080
print(f"Seed: {seed}")
torch.manual_seed(seed)
random.seed(seed)
Seed: 7080
Generate the maze.
In [2]:
env = GridEnvironment(25, 25, wall_ratio=0.07)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
grid_to_show = 2 * env.grid.clone().cpu().numpy()
grid_to_show[*env.starting_state] = 1
grid_to_show[*env.ending_state] = 1
ax.imshow(grid_to_show, cmap='Paired', interpolation='nearest')
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect(env.ncols / env.nrows)
plt.show()
Implement the agent. The agent maintains a state-value function.
In [3]:
class ValueIterationAgent():
    def __init__(self, env, discount_factor=0.9, eps=1e-6):
        self.env = env
        self.discount_factor = discount_factor
        self.eps = eps
        self.state_values = torch.zeros(env.num_states, dtype=torch.float32, device=env.device)

    @property
    def action_values(self):
        # Q^*(s, a) = R(s, a) + \gamma \sum_{s'} P(s'|s, a) V^*(s')
        return (
            self.discount_factor * self.env.state_transition_np * self.state_values
            + self.env.step_reward_np
        ).sum(axis=1).to_dense().reshape(self.env.num_states, self.env.num_actions)

    @property
    def policy(self):
        # Greedy policy: the action with the highest Q(s, a) in each state
        return self.action_values.argmax(axis=1)

    def step(self) -> float:
        # V^*(s) = \max_a Q^*(s, a)
        new_state_values = self.action_values.max(axis=1).values
        # Largest change in any state value, used as the convergence criterion
        delta = torch.abs(new_state_values - self.state_values).max().item()
        self.state_values = new_state_values
        return delta

    def __iter__(self):
        return self

    def __next__(self):
        delta = self.step()
        if delta < self.eps:
            raise StopIteration
        return delta

agent = ValueIterationAgent(env, discount_factor=0.98, eps=1e-6)
deltas = [*agent]
Plot the convergence of the state-value function updates.
In [4]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(deltas)
ax.set_xlabel('Iteration')
ax.set_ylabel('Max delta')
ax.set_title('Value Iteration Convergence')
ax.set_yscale('log')
ax.grid()
plt.show()
Display the state values and the policy.
In [5]:
fig, (ax_path, ax_state, ax_cmap) = plt.subplots(1, 3, figsize=(10, 4), width_ratios=[1, 1, 0.1])

# Roll out the greedy policy from the start state to the goal
ax_path.set_title('Main Path')
state = env.starting_state
path = [state]
while state != env.ending_state:
    action = agent.policy[env.get_state(state)].item()
    next_state = state + env.get_action(action)
    path.append(next_state)
    state = next_state
path_grid = grid_to_show.copy()
for index, (i, j) in enumerate(path):
    path_grid[i, j] = 1
ax_path.imshow(path_grid, cmap='Paired')

# Paint each state's value into the grid; cells without a corresponding state stay NaN
ax_state.set_title('State Values')
state_values_to_show = torch.full((env.nrows, env.ncols), torch.nan)
for i, value in enumerate(agent.state_values):
    state = env.get_state(i)
    state_values_to_show[*state] = value.item()
ax_state.imshow(state_values_to_show.cpu().numpy(), cmap='viridis', interpolation='nearest')
fig.colorbar(ax_state.images[0], cax=ax_cmap, orientation='vertical')

for ax in (ax_path, ax_state):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect(env.ncols / env.nrows)
plt.show()
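To probe the extension question about $\gamma$, one option is to sweep several discount factors and record how many iterations each run needs to converge. Below is a sketch that reuses the env instance and ValueIterationAgent class defined above; the specific $\gamma$ values are arbitrary.

gammas = [0.5, 0.8, 0.9, 0.95, 0.99]
iterations = []
for gamma in gammas:
    sweep_agent = ValueIterationAgent(env, discount_factor=gamma, eps=1e-6)
    sweep_deltas = [*sweep_agent]           # run value iteration until the max delta drops below eps
    iterations.append(len(sweep_deltas))    # number of sweeps needed to converge

fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.plot(gammas, iterations, marker='o')
ax.set_xlabel(r'Discount factor $\gamma$')
ax.set_ylabel('Iterations to converge')
ax.set_title('Effect of the discount factor on convergence')
ax.grid()
plt.show()

Larger $\gamma$ generally slows convergence: the value iteration update is a $\gamma$-contraction, so the error shrinks by roughly a factor of $\gamma$ per sweep.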