Reinforcement Learning with Ignite
In this tutorial we will implement a policy-gradient algorithm called REINFORCE and use it to solve the CartPole problem from OpenAI Gym (now Gymnasium) using PyTorch-Ignite.
Prerequisites
The reader should be familiar with basic Reinforcement Learning concepts such as state, action, and environment.
The CartPole Problem
We have to balance a pole that is attached to a cart, where the cart is free to move along a frictionless track. We balance the pole by moving the cart left or right in 1D. Let’s start by defining a few terms.
State
The environment is described by 4 variables: cart position, cart velocity, pole angle, and pole angular velocity.
Action space
There are 2 possible actions the agent can perform: push the cart to the left or to the right.
Reward
A reward of +1 is given for every timestep the pole remains upright and the cart stays within range.
When is it solved?
The problem is considered solved when the average reward is greater than the reward_threshold defined for the environment.
Required Dependencies
!pip install gymnasium pytorch-ignite
On Colab
We need additional dependencies to render the environment on Google Colab.
!apt-get install -y xvfb python-opengl
!pip install pyvirtualdisplay
!pip install --upgrade pygame moviepy
Imports
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from ignite.engine import Engine, Events
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
Configurable Parameters
We will use these values later in the tutorial at appropriate places.
seed_val = 543          # random seed for reproducibility
gamma = 0.99            # discount factor for future rewards
log_interval = 100      # log progress every `log_interval` episodes
max_episodes = 1000000  # upper bound on the number of episodes (engine epochs)
render = True           # render the environment at every step
Setting up the Environment
Let’s load our environment first.
env = gym.make("CartPole-v0", render_mode="rgb_array")
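Before moving on, it can help to sanity-check the environment we just created. The snippet below is an optional check (not part of the tutorial’s main flow) that prints the observation space, the action space, and the reward threshold we will use later to decide when training is done.

print(env.observation_space)       # Box with 4 values: cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)            # Discrete(2): 0 = push left, 1 = push right
print(env.spec.reward_threshold)   # running reward needed to consider the problem solved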
On Colab
If on Google Colab, we need to follow a list of steps to render the output. First we initialize our screen size.
display = Display(visible=0, size=(1400, 900))
display.start()
<pyvirtualdisplay.display.Display at 0x7f76f00bf810>
Below is a utility function that enables video recording of the gym environment. To enable recording, we wrap our environment with this function.
def wrap_env(env):
    env = RecordVideo(env, './video', disable_logger=True)
    return env

env = wrap_env(env)
Model
We are going to use the REINFORCE algorithm, in which the agent learns a policy directly from complete episodes sampled from the environment. Our model has two linear layers: the first takes the 4 state variables as input and produces 128 hidden features, and the second maps those 128 features to scores for the 2 actions. The output of the first layer is passed through a dropout layer and a ReLU before the second layer computes the score for each action; finally, a softmax turns these scores into a probability for each action. The model also keeps two buffers, saved_log_probs and rewards, which store the log-probabilities of the actions taken and the rewards received during the current episode.
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
And then we initialize our model, optimizer, epsilon and timesteps.
Here, eps is the float32 machine epsilon, a tiny constant added to the standard deviation later to avoid division by zero when normalizing returns. timesteps is simply the range of step indices the trainer iterates over within each episode; each step represents one interaction with the environment, i.e. given the current state and an action, the environment returns the new state and the reward achieved.
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()
timesteps = range(10000)
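As a quick, optional sanity check (not part of the original tutorial), we can confirm that the freshly initialized policy maps a 4-dimensional observation to a valid probability distribution over the 2 actions:

with torch.no_grad():
    probs = policy(torch.zeros(1, 4))   # a dummy batch with one 4-dimensional observation
print(probs, probs.sum())               # two action probabilities that sum to 1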
Create Trainer
Ignite’s Engine lets us define a process_function that runs a single timestep; one epoch of the engine then corresponds to one episode. In the process function we select an action from the policy, take that action in the environment via step(), and accumulate the reward. When the episode ends, we terminate the epoch and save the final timestep.
An episode is one play-through of the game (or one life of the game): when the game ends or a life is lost, the episode ends. A step, on the other hand, is a discrete counter that increases monotonically within an episode; with each change in the state of the game, the step value increases until the episode ends.
def run_single_timestep(engine, timestep):
    observation = engine.state.observation
    action = select_action(policy, observation)
    # gymnasium returns both `terminated` (pole fell / cart out of bounds)
    # and `truncated` (time limit reached); either one ends the episode
    engine.state.observation, reward, done, truncated, _ = env.step(action)
    if render:
        env.render()
    policy.rewards.append(reward)
    engine.state.ep_reward += reward
    if done or truncated:
        engine.terminate_epoch()
        engine.state.timestep = timestep

trainer = Engine(run_single_timestep)
Next we need to select an action to take. The policy gives us a list of probabilities; we create a categorical distribution over them and sample an action from it. The log-probability of the sampled action is saved to the action buffer, and the chosen action (left or right) is returned.
def select_action(policy, observation):
    state = torch.from_numpy(observation).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()
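For illustration, and purely as an optional check, we can sample one action for a fresh observation. Note that this test call appends a log-probability to the buffer, which we clear afterwards so it does not leak into training:

obs, _ = env.reset(seed=0)
print(select_action(policy, obs))   # prints 0 (push left) or 1 (push right)
policy.saved_log_probs.clear()      # discard the log-prob saved by this test call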
After each episode we compute the true discounted return for every timestep by walking backwards through the rewards (R = r + gamma * R), then normalize the returns to zero mean and unit variance. The policy loss for each timestep is -log_prob * R, and the total loss is their sum. Finally, we reset the gradients, backpropagate through the policy loss, take an optimizer step, and clear the rewards and log-probability buffers.
def finish_episode(policy, optimizer, gamma):
    R = 0
    policy_loss = []
    returns = deque()
    # compute discounted returns by iterating over the rewards in reverse
    for r in policy.rewards[::-1]:
        R = r + gamma * R
        returns.appendleft(R)
    returns = torch.tensor(returns)
    # normalize the returns (eps avoids division by zero)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    # clear the episode buffers
    del policy.rewards[:]
    del policy.saved_log_probs[:]
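To make the return computation concrete, here is a small worked example (illustrative only; it reuses gamma and deque from above). For a 3-step episode with rewards [1, 1, 1] and gamma = 0.99, the discounted returns are [1 + 0.99 * 1.99, 1 + 0.99 * 1, 1], i.e. approximately [2.9701, 1.99, 1.0].

example_rewards = [1.0, 1.0, 1.0]
R, example_returns = 0.0, deque()
for r in example_rewards[::-1]:
    R = r + gamma * R
    example_returns.appendleft(R)
print(list(example_returns))   # approximately [2.9701, 1.99, 1.0]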
Attach handlers to run on specific events
Since one epoch of the engine corresponds to one episode, we alias the start and end epoch events for readability.
EPISODE_STARTED = Events.EPOCH_STARTED
EPISODE_COMPLETED = Events.EPOCH_COMPLETED
Before training begins, we initialize the running reward in the trainer’s state.
trainer.state.running_reward = 10
When an episode begins, we have to reset the environment’s state.
@trainer.on(EPISODE_STARTED)
def reset_environment_state():
    torch.manual_seed(seed_val + trainer.state.epoch)
    trainer.state.observation, _ = env.reset(seed=seed_val + trainer.state.epoch)
    trainer.state.ep_reward = 0
When an episode finishes, we update the running reward (an exponential moving average with smoothing factor 0.05) and update the policy by calling finish_episode().
@trainer.on(EPISODE_COMPLETED)
def update_model():
    trainer.state.running_reward = 0.05 * trainer.state.ep_reward + (1 - 0.05) * trainer.state.running_reward
    finish_episode(policy, optimizer, gamma)
After that, every 100 (log_interval) episodes, we log the results.
@trainer.on(EPISODE_COMPLETED(every=log_interval))
def log_episode():
    i_episode = trainer.state.epoch
    print(
        f"Episode {i_episode}\tLast reward: {trainer.state.ep_reward:.2f}"
        f"\tAverage length: {trainer.state.running_reward:.2f}"
    )
And finally, we check if our running reward has crossed the threshold so that we can stop training.
@trainer.on(EPISODE_COMPLETED)
def should_finish_training():
    running_reward = trainer.state.running_reward
    if running_reward > env.spec.reward_threshold:
        print(
            f"Solved! Running reward is now {running_reward} and "
            f"the last episode runs to {trainer.state.timestep} time steps!"
        )
        trainer.should_terminate = True
Run Trainer
trainer.run(timesteps, max_epochs=max_episodes)
Episode 100 Last length: 66 Average length: 37.90
Episode 200 Last length: 21 Average length: 115.82
Episode 300 Last length: 199 Average length: 133.13
Episode 400 Last length: 98 Average length: 134.97
Episode 500 Last length: 77 Average length: 77.39
Episode 600 Last length: 199 Average length: 132.99
Episode 700 Last length: 122 Average length: 137.40
Episode 800 Last length: 39 Average length: 159.51
Episode 900 Last length: 86 Average length: 113.31
Episode 1000 Last length: 76 Average length: 114.67
Episode 1100 Last length: 96 Average length: 98.65
Episode 1200 Last length: 90 Average length: 84.50
Episode 1300 Last length: 102 Average length: 89.10
Episode 1400 Last length: 64 Average length: 86.45
Episode 1500 Last length: 60 Average length: 76.35
Episode 1600 Last length: 75 Average length: 71.38
Episode 1700 Last length: 176 Average length: 117.25
Episode 1800 Last length: 139 Average length: 140.96
Episode 1900 Last length: 63 Average length: 141.79
Episode 2000 Last length: 66 Average length: 94.01
Episode 2100 Last length: 199 Average length: 115.46
Episode 2200 Last length: 113 Average length: 137.11
Episode 2300 Last length: 174 Average length: 135.36
Episode 2400 Last length: 80 Average length: 116.46
Episode 2500 Last length: 96 Average length: 101.47
Episode 2600 Last length: 199 Average length: 141.13
Episode 2700 Last length: 13 Average length: 134.91
Episode 2800 Last length: 90 Average length: 71.22
Episode 2900 Last length: 61 Average length: 70.14
Episode 3000 Last length: 199 Average length: 129.67
Episode 3100 Last length: 199 Average length: 173.62
Episode 3200 Last length: 199 Average length: 189.30
Solved! Running reward is now 195.03268327777783 and the last episode runs to 199 time steps!
State:
iteration: 396569
epoch: 3289
epoch_length: 10000
max_epochs: 1000000
output: <class 'NoneType'>
batch: 199
metrics: <class 'dict'>
dataloader: <class 'list'>
seed: <class 'NoneType'>
times: <class 'dict'>
running_reward: 195.03268327777783
observation: <class 'numpy.ndarray'>
timestep: 199
env.close()
On Colab
Finally, we can view our saved video.
mp4list = glob.glob('video/*.mp4')
if len(mp4list) > 0:
    mp4 = mp4list[-1]  # pick the last video
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay
        loop controls style="height: 400px;">
        <source src="data:video/mp4;base64,{0}" type="video/mp4" />
    </video>'''.format(encoded.decode('ascii'))))
else:
    print("Could not find video")
That’s it! We have successfully solved the CartPole problem!