diff --git a/squiRL/common/agents.py b/squiRL/common/agents.py index 66e815e..e891dd2 100644 --- a/squiRL/common/agents.py +++ b/squiRL/common/agents.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Tuple +from typing import Tuple, List from squiRL.common.data_stream import Experience @@ -17,25 +17,34 @@ class Agent: env: training environment Attributes: - env (gym.Env): OpenAI gym training environment + env (List[gym.Env]): List of OpenAI gym training environment obs (int): Array of env observation state replay_buffer (TYPE): Data collector for saving experience """ - def __init__(self, env: gym.Env, replay_buffer) -> None: + def __init__(self, env: List[gym.Env], replay_buffer) -> None: """Initializes agent class Args: - env (gym.Env): OpenAI gym training environment + env (List[gym.Env]): List of OpenAI gym training environment replay_buffer (TYPE): Data collector for saving experience """ - self.env = env + self.envs = env + self.obs = [None] * len(self.envs) self.replay_buffer = replay_buffer - self.reset() + + for i in range(len(self.envs)): + self.env_idx = i + self.reset() + + self.env_idx = 0 def reset(self) -> None: """Resets the environment and updates the obs """ - self.obs = self.env.reset() + self.obs[self.env_idx] = self.envs[self.env_idx].reset() + + def next_env(self) -> None: + self.env_idx = (self.env_idx + 1) % len(self.envs) def process_obs(self, obs: int) -> torch.Tensor: """Converts obs np.array to torch.Tensor for passing through NN @@ -61,7 +70,9 @@ def get_action( Returns: action (int): Action to be carried out """ - obs = self.process_obs(self.obs) + obs = self.obs[self.env_idx] + assert obs is not None + obs = self.process_obs(obs) action_logit = net(obs) probs = F.softmax(action_logit, dim=-1) @@ -89,11 +100,11 @@ def play_step( action = self.get_action(net) # do step in the environment - new_obs, reward, done, _ = self.env.step(action) - exp = Experience(self.obs, action, reward, done, new_obs) + new_obs, reward, done, _ = self.envs[self.env_idx].step(action) + exp = Experience(self.obs[self.env_idx], action, reward, done, new_obs) self.replay_buffer.append(exp) - self.obs = new_obs + self.obs[self.env_idx] = new_obs if done: self.reset() return reward, done diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index 8917e6a..0192af2 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -3,14 +3,16 @@ Attributes: Experience (namedtuple): An environment step experience """ -import numpy as np -from torch.utils.data.dataset import IterableDataset from collections import deque from collections import namedtuple -from squiRL.common.policies import MLP -import gym from typing import Tuple +from torch.utils.data.dataset import IterableDataset +import torch.multiprocessing as mp +import torch + +from squiRL.common.policies import MLP + Experience = namedtuple('Experience', ('state', 'action', 'reward', 'done', 'last_state')) @@ -27,14 +29,33 @@ class RolloutCollector: capacity (int): Size of the buffer replay_buffer (deque): Experience buffer """ - def __init__(self, capacity: int) -> None: + def __init__(self, capacity: int, state_shape: tuple, action_shape: tuple, should_share: bool = False) -> None: """Summary Args: capacity (int): Description """ + + state_shape = [capacity] + list(state_shape) + action_shape = [capacity] + list(action_shape) + self.capacity = capacity - self.replay_buffer = deque(maxlen=self.capacity) + self.count = torch.tensor([0], dtype=torch.int64) + self.states = torch.zeros(state_shape, dtype=torch.float32) + self.actions = torch.zeros(action_shape, dtype=torch.float32) + self.rewards = torch.zeros((capacity), dtype=torch.float32) + self.dones = torch.zeros((capacity), dtype=torch.bool) + self.next_states = torch.zeros(state_shape, dtype=torch.float32) + + if should_share: + self.count.share_memory_() + self.states.share_memory_() + self.actions.share_memory_() + self.next_states.share_memory_() + self.rewards.share_memory_() + self.dones.share_memory_() + + self.lock = mp.Lock() def __len__(self) -> int: """Calculates length of buffer @@ -42,7 +63,7 @@ def __len__(self) -> int: Returns: int: Length of buffer """ - return len(self.replay_buffer) + return self.count.detach().numpy().item() def append(self, experience: Experience) -> None: """ @@ -50,9 +71,25 @@ def append(self, experience: Experience) -> None: Args: experience (Experience): Tuple (state, action, reward, done, - new_state) + last_state) """ - self.replay_buffer.append(experience) + + with self.lock: + if self.count[0] < self.capacity: + self.count[0] += 1 + + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + + self.states[nr] = torch.tensor(experience.state, dtype=torch.float32) + self.actions[nr] = torch.tensor(experience.action, dtype=torch.float32) + self.rewards[nr] = torch.tensor(experience.reward, dtype=torch.float32) + self.dones[nr] = torch.tensor(experience.done, dtype=torch.bool) + self.next_states[nr] = torch.tensor(experience.last_state, dtype=torch.float32) + + else: + exit("RolloutCollector: Buffer is full but samples are being added to it") + def sample(self) -> Tuple: """Sample experience from buffer @@ -60,17 +97,17 @@ def sample(self) -> Tuple: Returns: Tuple: Sampled experience """ - states, actions, rewards, dones, next_states = zip( - *[self.replay_buffer[i] for i in range(len(self.replay_buffer))]) - return (np.array(states), np.array(actions), - np.array(rewards, dtype=np.float32), - np.array(dones, dtype=np.bool), np.array(next_states)) + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + return (self.states[:nr], self.actions[:nr], self.rewards[:nr], self.dones[:nr], self.next_states[:nr]) def empty_buffer(self) -> None: - """Empty replay buffer + """Empty replay buffer by resetting the count (so old data gets overwritten) """ - self.replay_buffer.clear() + with self.lock: + # the [0] is very important, otherwise we throw the tensor out and the int that replaces it won't get shared + self.count[0] = 0 class RLDataset(IterableDataset): @@ -84,24 +121,23 @@ class RLDataset(IterableDataset): Attributes: agent (Agent): Agent that interacts with env - env (gym.Env): OpenAI gym environment net (nn.Module): Policy network replay_buffer: Replay buffer """ - def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, - agent) -> None: + + def __init__(self, replay_buffer: RolloutCollector, net: MLP, + agent, episodes_per_batch: int = 1) -> None: """Summary Args: replay_buffer (RolloutCollector): Description - env (gym.Env): OpenAI gym environment net (nn.Module): Policy network agent (Agent): Agent that interacts with env """ self.replay_buffer = replay_buffer - self.env = env self.net = net self.agent = agent + self.episodes_per_batch = episodes_per_batch def populate(self) -> None: """ @@ -119,8 +155,10 @@ def __iter__(self): Yields: Tuple: Sampled experience """ - for i in range(1): - self.populate() - states, actions, rewards, dones, new_states = self.replay_buffer.sample( - ) - yield (states, actions, rewards, dones, new_states) + for i in range(self.episodes_per_batch): + for j in range(len(self.agent.envs)): + self.agent.env_idx = j + self.populate() + states, actions, rewards, dones, new_states = self.replay_buffer.sample( + ) + yield (states, actions, rewards, dones, new_states) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 91bd846..30701d7 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -2,8 +2,10 @@ """ import argparse from argparse import ArgumentParser +from collections import OrderedDict from copy import copy from typing import Tuple, List + import gym import numpy as np import pytorch_lightning as pl @@ -13,11 +15,10 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.data._utils import collate -from collections import OrderedDict from squiRL.common import reg_policies -from squiRL.common.data_stream import RLDataset, RolloutCollector from squiRL.common.agents import Agent +from squiRL.common.data_stream import RLDataset, RolloutCollector class VPG(pl.LightningModule): @@ -41,17 +42,25 @@ def __init__(self, hparams: argparse.Namespace) -> None: super(VPG, self).__init__() self.hparams = hparams - self.env = gym.make(self.hparams.env) + self.num_envs = hparams.num_envs + self.env = [gym.make(self.hparams.env) for i in range(self.num_envs)] self.gamma = self.hparams.gamma self.eps = self.hparams.eps - obs_size = self.env.observation_space.shape[0] - n_actions = self.env.action_space.n + self.episodes_per_batch = self.hparams.episodes_per_batch + self.num_workers = hparams.num_workers + # Assuming all envs used have the same obs and action space, the first one is used to extract this info + obs_size = self.env[0].observation_space.shape[0] + action_size = self.env[0].action_space.shape + n_actions = self.env[0].action_space.n self.net = reg_policies[self.hparams.policy](obs_size, n_actions) - self.replay_buffer = RolloutCollector(self.hparams.episode_length) + self.replay_buffer = RolloutCollector(self.hparams.episode_length, (obs_size,), action_size) self.agent = Agent(self.env, self.replay_buffer) + if hparams.profiler is not None: + self.profiler = hparams.profiler + @staticmethod def add_model_specific_args( parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -84,6 +93,15 @@ def add_model_specific_args( type=int, default=20, help="num of dataloader cpu workers") + parser.add_argument("--episodes_per_batch", + type=int, + default=1, + help="number of episodes to be sampled per training step") + parser.add_argument("--num_envs", + type=int, + default=1, + help="number of environments to be sequentially sampled from") + return parser def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: @@ -104,8 +122,8 @@ def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: res.append(copy(sum_r)) return list(reversed(res)) - def vpg_loss(self, batch: Tuple[torch.Tensor, - torch.Tensor]) -> torch.Tensor: + def vpg_loss(self, + batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: """ Calculates the loss based on the REINFORCE objective, using the discounted @@ -122,20 +140,22 @@ def vpg_loss(self, batch: Tuple[torch.Tensor, action_logit = self.net(states.float()) log_probs = F.log_softmax(action_logit, - dim=-1).squeeze(0)[range(len(actions)), - actions] + dim=-1)[range(len(actions)), + actions.long()] + + discounted_rewards = self.reward_to_go(rewards) discounted_rewards = torch.tensor(discounted_rewards) advantage = (discounted_rewards - discounted_rewards.mean()) / ( - discounted_rewards.std() + self.eps) + discounted_rewards.std() + self.eps) advantage = advantage.type_as(log_probs) loss = -advantage * log_probs return loss.sum() - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], - nb_batch) -> OrderedDict: + def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + nb_batch) -> pl.TrainResult: """ Carries out an entire episode in env and calculates loss @@ -143,18 +163,25 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], OrderedDict: Training step result Args: - batch (Tuple[torch.Tensor, torch.Tensor]): Current mini batch of - replay data + batch (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]): Current + mini batch of replay data nb_batch (TYPE): Current index of mini batch of replay data """ - _, _, rewards, _, _ = batch - episode_reward = rewards.sum().detach() + loss = None + episode_rewards = [] + for episode in batch: + _, _, rewards, _, _ = episode + episode_rewards.append(rewards.sum().detach()) - loss = self.vpg_loss(batch) + if loss is None: + loss = self.vpg_loss(episode) + else: + loss += self.vpg_loss(episode) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) + mean_episode_reward = torch.tensor(np.mean(episode_rewards)) result = pl.TrainResult(loss) result.log('loss', loss, @@ -162,8 +189,8 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], on_epoch=True, prog_bar=False, logger=True) - result.log('episode_reward', - episode_reward, + result.log('mean_episode_reward', + mean_episode_reward, on_step=True, on_epoch=True, prog_bar=True, @@ -191,13 +218,7 @@ def collate_fn(self, batch): """ batch = collate.default_convert(batch) - states = torch.cat([s[0] for s in batch]) - actions = torch.cat([s[1] for s in batch]) - rewards = torch.cat([s[2] for s in batch]) - dones = torch.cat([s[3] for s in batch]) - next_states = torch.cat([s[4] for s in batch]) - - return states, actions, rewards, dones, next_states + return batch def __dataloader(self) -> DataLoader: """Initialize the RL dataset used for retrieving experiences @@ -205,11 +226,13 @@ def __dataloader(self) -> DataLoader: Returns: DataLoader: Handles loading the data for training """ - dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent) + dataset = RLDataset(self.replay_buffer, self.net, self.agent, self.episodes_per_batch) dataloader = DataLoader( dataset=dataset, collate_fn=self.collate_fn, - batch_size=1, + batch_size=self.episodes_per_batch, + num_workers=self.num_workers, + pin_memory=True ) return dataloader diff --git a/train.py b/train.py index a8bb7ed..b7e9d3c 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,7 @@ def main(hparams) -> None: if hparams.debug: hparams.logger = None hparams.profiler = True - hparams.num_workers = None + hparams.num_workers = 0 else: hparams.logger = WandbLogger(project=hparams.project) seed_everything(hparams.seed)