From 957f052821b4b0e6fdf1b31e9dbd0bad79ee9f22 Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Wed, 14 Oct 2020 18:22:31 +0100 Subject: [PATCH 1/7] Added option for multiple episodes per batch Re-arranged imports --- squiRL/common/data_stream.py | 16 +++++++----- squiRL/vpg/vpg.py | 47 ++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index 8917e6a..93798c4 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -3,14 +3,16 @@ Attributes: Experience (namedtuple): An environment step experience """ -import numpy as np -from torch.utils.data.dataset import IterableDataset from collections import deque from collections import namedtuple -from squiRL.common.policies import MLP -import gym from typing import Tuple +import gym +import numpy as np +from torch.utils.data.dataset import IterableDataset + +from squiRL.common.policies import MLP + Experience = namedtuple('Experience', ('state', 'action', 'reward', 'done', 'last_state')) @@ -88,8 +90,9 @@ class RLDataset(IterableDataset): net (nn.Module): Policy network replay_buffer: Replay buffer """ + def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, - agent) -> None: + agent, episodes_per_batch: int = 1) -> None: """Summary Args: @@ -102,6 +105,7 @@ def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, self.env = env self.net = net self.agent = agent + self.episodes_per_batch = episodes_per_batch def populate(self) -> None: """ @@ -119,7 +123,7 @@ def __iter__(self): Yields: Tuple: Sampled experience """ - for i in range(1): + for i in range(self.episodes_per_batch): self.populate() states, actions, rewards, dones, new_states = self.replay_buffer.sample( ) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 91bd846..06663fb 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -2,8 +2,10 @@ """ import argparse from argparse import ArgumentParser +from collections import OrderedDict from copy import copy from typing import Tuple, List + import gym import numpy as np import pytorch_lightning as pl @@ -13,11 +15,10 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torch.utils.data._utils import collate -from collections import OrderedDict from squiRL.common import reg_policies -from squiRL.common.data_stream import RLDataset, RolloutCollector from squiRL.common.agents import Agent +from squiRL.common.data_stream import RLDataset, RolloutCollector class VPG(pl.LightningModule): @@ -44,6 +45,7 @@ def __init__(self, hparams: argparse.Namespace) -> None: self.env = gym.make(self.hparams.env) self.gamma = self.hparams.gamma self.eps = self.hparams.eps + self.episodes_per_batch = self.hparams.episodes_per_batch obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n @@ -84,6 +86,10 @@ def add_model_specific_args( type=int, default=20, help="num of dataloader cpu workers") + parser.add_argument("--episodes_per_batch", + type=int, + default=1, + help="number of episodes to be sampled per training step") return parser def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: @@ -104,8 +110,8 @@ def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: res.append(copy(sum_r)) return list(reversed(res)) - def vpg_loss(self, batch: Tuple[torch.Tensor, - torch.Tensor]) -> torch.Tensor: + def vpg_loss(self, + batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: """ Calculates the loss based on the REINFORCE objective, using the discounted @@ -128,14 +134,14 @@ def vpg_loss(self, batch: Tuple[torch.Tensor, discounted_rewards = self.reward_to_go(rewards) discounted_rewards = torch.tensor(discounted_rewards) advantage = (discounted_rewards - discounted_rewards.mean()) / ( - discounted_rewards.std() + self.eps) + discounted_rewards.std() + self.eps) advantage = advantage.type_as(log_probs) loss = -advantage * log_probs return loss.sum() - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], - nb_batch) -> OrderedDict: + def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + nb_batch) -> pl.TrainResult: """ Carries out an entire episode in env and calculates loss @@ -143,14 +149,19 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], OrderedDict: Training step result Args: - batch (Tuple[torch.Tensor, torch.Tensor]): Current mini batch of - replay data + batch (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]): Current + mini batch of replay data nb_batch (TYPE): Current index of mini batch of replay data """ - _, _, rewards, _, _ = batch - episode_reward = rewards.sum().detach() + loss = None + for episode in batch: + _, _, rewards, _, _ = episode + episode_reward = rewards.sum().detach() - loss = self.vpg_loss(batch) + if loss is None: + loss = self.vpg_loss(episode) + else: + loss += self.vpg_loss(episode) if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) @@ -191,13 +202,7 @@ def collate_fn(self, batch): """ batch = collate.default_convert(batch) - states = torch.cat([s[0] for s in batch]) - actions = torch.cat([s[1] for s in batch]) - rewards = torch.cat([s[2] for s in batch]) - dones = torch.cat([s[3] for s in batch]) - next_states = torch.cat([s[4] for s in batch]) - - return states, actions, rewards, dones, next_states + return batch def __dataloader(self) -> DataLoader: """Initialize the RL dataset used for retrieving experiences @@ -205,11 +210,11 @@ def __dataloader(self) -> DataLoader: Returns: DataLoader: Handles loading the data for training """ - dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent) + dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent, self.episodes_per_batch) dataloader = DataLoader( dataset=dataset, collate_fn=self.collate_fn, - batch_size=1, + batch_size=self.episodes_per_batch, ) return dataloader From b68d168895201f23b09f738dce8310077e50632c Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Sat, 17 Oct 2020 17:23:27 +0100 Subject: [PATCH 2/7] Prepare code for shared buffer Add num_workers to VPG --- squiRL/common/data_stream.py | 61 +++++++++++++++++++++++++++++------- squiRL/vpg/vpg.py | 11 +++++-- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index 93798c4..c012ca8 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -10,6 +10,8 @@ import gym import numpy as np from torch.utils.data.dataset import IterableDataset +import torch.multiprocessing as mp +import torch from squiRL.common.policies import MLP @@ -29,14 +31,33 @@ class RolloutCollector: capacity (int): Size of the buffer replay_buffer (deque): Experience buffer """ - def __init__(self, capacity: int) -> None: + def __init__(self, capacity: int, state_shape: tuple, action_shape: tuple, should_share: bool = False) -> None: """Summary Args: capacity (int): Description """ + + state_shape = [capacity] + list(state_shape) + action_shape = [capacity] + list(action_shape) + self.capacity = capacity - self.replay_buffer = deque(maxlen=self.capacity) + self.count = torch.tensor([0], dtype=torch.int64) + self.states = torch.zeros(state_shape, dtype=torch.float32) + self.actions = torch.zeros(action_shape, dtype=torch.float32) + self.rewards = torch.zeros((capacity), dtype=torch.float32) + self.dones = torch.zeros((capacity), dtype=torch.bool) + self.next_states = torch.zeros(state_shape, dtype=torch.float32) + + if should_share: + self.count.share_memory_() + self.states.share_memory_() + self.actions.share_memory_() + self.next_states.share_memory_() + self.rewards.share_memory_() + self.dones.share_memory_() + + self.lock = mp.Lock() def __len__(self) -> int: """Calculates length of buffer @@ -44,7 +65,7 @@ def __len__(self) -> int: Returns: int: Length of buffer """ - return len(self.replay_buffer) + return self.count.detach().numpy().item() def append(self, experience: Experience) -> None: """ @@ -52,9 +73,25 @@ def append(self, experience: Experience) -> None: Args: experience (Experience): Tuple (state, action, reward, done, - new_state) + last_state) """ - self.replay_buffer.append(experience) + + with self.lock: + if self.count[0] < self.capacity: + self.count[0] += 1 + + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + + self.states[nr] = torch.tensor(experience.state, dtype=torch.float32) + self.actions[nr] = torch.tensor(experience.action, dtype=torch.float32) + self.rewards[nr] = torch.tensor(experience.reward, dtype=torch.float32) + self.dones[nr] = torch.tensor(experience.done, dtype=torch.bool) + self.next_states[nr] = torch.tensor(experience.last_state, dtype=torch.float32) + + else: + exit("RolloutCollector: Buffer is full but samples are being added to it") + def sample(self) -> Tuple: """Sample experience from buffer @@ -62,17 +99,17 @@ def sample(self) -> Tuple: Returns: Tuple: Sampled experience """ - states, actions, rewards, dones, next_states = zip( - *[self.replay_buffer[i] for i in range(len(self.replay_buffer))]) - return (np.array(states), np.array(actions), - np.array(rewards, dtype=np.float32), - np.array(dones, dtype=np.bool), np.array(next_states)) + # count keeps the exact length, but indexing starts from 0 so we decrease by 1 + nr = self.count[0] - 1 + return (self.states[:nr], self.actions[:nr], self.rewards[:nr], self.dones[:nr], self.next_states[:nr]) def empty_buffer(self) -> None: - """Empty replay buffer + """Empty replay buffer by resetting the count (so old data gets overwritten) """ - self.replay_buffer.clear() + with self.lock: + # the [0] is very important, otherwise we throw the tensor out and the int that replaces it won't get shared + self.count[0] = 0 class RLDataset(IterableDataset): diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 06663fb..87f16d0 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -46,11 +46,12 @@ def __init__(self, hparams: argparse.Namespace) -> None: self.gamma = self.hparams.gamma self.eps = self.hparams.eps self.episodes_per_batch = self.hparams.episodes_per_batch + self.num_workers = hparams.num_workers obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n self.net = reg_policies[self.hparams.policy](obs_size, n_actions) - self.replay_buffer = RolloutCollector(self.hparams.episode_length) + self.replay_buffer = RolloutCollector(self.hparams.episode_length, self.env.observation_space.shape, self.env.action_space.shape) self.agent = Agent(self.env, self.replay_buffer) @@ -128,8 +129,10 @@ def vpg_loss(self, action_logit = self.net(states.float()) log_probs = F.log_softmax(action_logit, - dim=-1).squeeze(0)[range(len(actions)), - actions] + dim=-1)[range(len(actions)), + actions.long()] + + discounted_rewards = self.reward_to_go(rewards) discounted_rewards = torch.tensor(discounted_rewards) @@ -215,6 +218,8 @@ def __dataloader(self) -> DataLoader: dataset=dataset, collate_fn=self.collate_fn, batch_size=self.episodes_per_batch, + num_workers=self.num_workers, + pin_memory=True ) return dataloader From b5b4b9863925f97c018b01a68ad8bafb305949d9 Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Sat, 17 Oct 2020 18:33:19 +0100 Subject: [PATCH 3/7] Add num_envs Agent handles list of envs Loop through envs in RolloutCollector Remove env from RLDataset --- squiRL/common/agents.py | 33 ++++++++++++++++++++++----------- squiRL/common/data_stream.py | 17 +++++++---------- squiRL/vpg/vpg.py | 18 +++++++++++++----- train.py | 2 +- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/squiRL/common/agents.py b/squiRL/common/agents.py index 66e815e..e891dd2 100644 --- a/squiRL/common/agents.py +++ b/squiRL/common/agents.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Tuple +from typing import Tuple, List from squiRL.common.data_stream import Experience @@ -17,25 +17,34 @@ class Agent: env: training environment Attributes: - env (gym.Env): OpenAI gym training environment + env (List[gym.Env]): List of OpenAI gym training environment obs (int): Array of env observation state replay_buffer (TYPE): Data collector for saving experience """ - def __init__(self, env: gym.Env, replay_buffer) -> None: + def __init__(self, env: List[gym.Env], replay_buffer) -> None: """Initializes agent class Args: - env (gym.Env): OpenAI gym training environment + env (List[gym.Env]): List of OpenAI gym training environment replay_buffer (TYPE): Data collector for saving experience """ - self.env = env + self.envs = env + self.obs = [None] * len(self.envs) self.replay_buffer = replay_buffer - self.reset() + + for i in range(len(self.envs)): + self.env_idx = i + self.reset() + + self.env_idx = 0 def reset(self) -> None: """Resets the environment and updates the obs """ - self.obs = self.env.reset() + self.obs[self.env_idx] = self.envs[self.env_idx].reset() + + def next_env(self) -> None: + self.env_idx = (self.env_idx + 1) % len(self.envs) def process_obs(self, obs: int) -> torch.Tensor: """Converts obs np.array to torch.Tensor for passing through NN @@ -61,7 +70,9 @@ def get_action( Returns: action (int): Action to be carried out """ - obs = self.process_obs(self.obs) + obs = self.obs[self.env_idx] + assert obs is not None + obs = self.process_obs(obs) action_logit = net(obs) probs = F.softmax(action_logit, dim=-1) @@ -89,11 +100,11 @@ def play_step( action = self.get_action(net) # do step in the environment - new_obs, reward, done, _ = self.env.step(action) - exp = Experience(self.obs, action, reward, done, new_obs) + new_obs, reward, done, _ = self.envs[self.env_idx].step(action) + exp = Experience(self.obs[self.env_idx], action, reward, done, new_obs) self.replay_buffer.append(exp) - self.obs = new_obs + self.obs[self.env_idx] = new_obs if done: self.reset() return reward, done diff --git a/squiRL/common/data_stream.py b/squiRL/common/data_stream.py index c012ca8..0192af2 100644 --- a/squiRL/common/data_stream.py +++ b/squiRL/common/data_stream.py @@ -7,8 +7,6 @@ from collections import namedtuple from typing import Tuple -import gym -import numpy as np from torch.utils.data.dataset import IterableDataset import torch.multiprocessing as mp import torch @@ -123,23 +121,20 @@ class RLDataset(IterableDataset): Attributes: agent (Agent): Agent that interacts with env - env (gym.Env): OpenAI gym environment net (nn.Module): Policy network replay_buffer: Replay buffer """ - def __init__(self, replay_buffer: RolloutCollector, env: gym.Env, net: MLP, + def __init__(self, replay_buffer: RolloutCollector, net: MLP, agent, episodes_per_batch: int = 1) -> None: """Summary Args: replay_buffer (RolloutCollector): Description - env (gym.Env): OpenAI gym environment net (nn.Module): Policy network agent (Agent): Agent that interacts with env """ self.replay_buffer = replay_buffer - self.env = env self.net = net self.agent = agent self.episodes_per_batch = episodes_per_batch @@ -161,7 +156,9 @@ def __iter__(self): Tuple: Sampled experience """ for i in range(self.episodes_per_batch): - self.populate() - states, actions, rewards, dones, new_states = self.replay_buffer.sample( - ) - yield (states, actions, rewards, dones, new_states) + for j in range(len(self.agent.envs)): + self.agent.env_idx = j + self.populate() + states, actions, rewards, dones, new_states = self.replay_buffer.sample( + ) + yield (states, actions, rewards, dones, new_states) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 87f16d0..f0fb016 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -42,16 +42,19 @@ def __init__(self, hparams: argparse.Namespace) -> None: super(VPG, self).__init__() self.hparams = hparams - self.env = gym.make(self.hparams.env) + self.num_envs = hparams.num_envs + self.env = [gym.make(self.hparams.env) for i in range(self.num_envs)] self.gamma = self.hparams.gamma self.eps = self.hparams.eps self.episodes_per_batch = self.hparams.episodes_per_batch self.num_workers = hparams.num_workers - obs_size = self.env.observation_space.shape[0] - n_actions = self.env.action_space.n + # Assuming all envs used have the same obs and action space, the first one is used to extract this info + obs_size = self.env[0].observation_space.shape[0] + action_size = self.env[0].action_space.shape + n_actions = self.env[0].action_space.n self.net = reg_policies[self.hparams.policy](obs_size, n_actions) - self.replay_buffer = RolloutCollector(self.hparams.episode_length, self.env.observation_space.shape, self.env.action_space.shape) + self.replay_buffer = RolloutCollector(self.hparams.episode_length, (obs_size,), action_size) self.agent = Agent(self.env, self.replay_buffer) @@ -91,6 +94,11 @@ def add_model_specific_args( type=int, default=1, help="number of episodes to be sampled per training step") + parser.add_argument("--num_envs", + type=int, + default=1, + help="number of environments to be sequentially sampled from") + return parser def reward_to_go(self, rewards: torch.Tensor) -> torch.tensor: @@ -213,7 +221,7 @@ def __dataloader(self) -> DataLoader: Returns: DataLoader: Handles loading the data for training """ - dataset = RLDataset(self.replay_buffer, self.env, self.net, self.agent, self.episodes_per_batch) + dataset = RLDataset(self.replay_buffer, self.net, self.agent, self.episodes_per_batch) dataloader = DataLoader( dataset=dataset, collate_fn=self.collate_fn, diff --git a/train.py b/train.py index a8bb7ed..24cb08f 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,7 @@ def main(hparams) -> None: if hparams.debug: hparams.logger = None hparams.profiler = True - hparams.num_workers = None + hparams.num_workers = 1 else: hparams.logger = WandbLogger(project=hparams.project) seed_everything(hparams.seed) From 18894cbd93580ed7bbee88369162d6a7fd820760 Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Sun, 18 Oct 2020 12:01:19 +0100 Subject: [PATCH 4/7] Num workers when debugging should be 0 not 1 --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 24cb08f..b7e9d3c 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,7 @@ def main(hparams) -> None: if hparams.debug: hparams.logger = None hparams.profiler = True - hparams.num_workers = 1 + hparams.num_workers = 0 else: hparams.logger = WandbLogger(project=hparams.project) seed_everything(hparams.seed) From 1d9c5dc04f1ce4c4ca2cb8fb671b0dc50835260b Mon Sep 17 00:00:00 2001 From: MihaiAnca13 Date: Sun, 18 Oct 2020 13:42:05 +0100 Subject: [PATCH 5/7] Add simple profiler --- squiRL/vpg/vpg.py | 3 +++ train.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index f0fb016..70095cd 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -58,6 +58,9 @@ def __init__(self, hparams: argparse.Namespace) -> None: self.agent = Agent(self.env, self.replay_buffer) + if hparams.profiler is not None: + self.profiler = hparams.profiler + @staticmethod def add_model_specific_args( parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: diff --git a/train.py b/train.py index b7e9d3c..4605c7c 100644 --- a/train.py +++ b/train.py @@ -10,6 +10,7 @@ from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities.seed import seed_everything import pytorch_lightning as pl +from pytorch_lightning.profiler import SimpleProfiler import squiRL @@ -21,7 +22,7 @@ def main(hparams) -> None: """ if hparams.debug: hparams.logger = None - hparams.profiler = True + hparams.profiler = SimpleProfiler() hparams.num_workers = 0 else: hparams.logger = WandbLogger(project=hparams.project) From bf070c614b1b89f8fd52b8df1fb5c5749ff5cab1 Mon Sep 17 00:00:00 2001 From: AGKhalil Date: Sun, 18 Oct 2020 19:56:26 +0000 Subject: [PATCH 6/7] profiler pl removed cause pickle error --- train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train.py b/train.py index 4605c7c..b7e9d3c 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,6 @@ from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities.seed import seed_everything import pytorch_lightning as pl -from pytorch_lightning.profiler import SimpleProfiler import squiRL @@ -22,7 +21,7 @@ def main(hparams) -> None: """ if hparams.debug: hparams.logger = None - hparams.profiler = SimpleProfiler() + hparams.profiler = True hparams.num_workers = 0 else: hparams.logger = WandbLogger(project=hparams.project) From 705225da79d5f68e558fdc82c5e9f8efdb9388bc Mon Sep 17 00:00:00 2001 From: AGKhalil Date: Sun, 18 Oct 2020 20:12:54 +0000 Subject: [PATCH 7/7] logging mean episode loss --- squiRL/vpg/vpg.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/squiRL/vpg/vpg.py b/squiRL/vpg/vpg.py index 70095cd..30701d7 100644 --- a/squiRL/vpg/vpg.py +++ b/squiRL/vpg/vpg.py @@ -168,9 +168,10 @@ def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tens nb_batch (TYPE): Current index of mini batch of replay data """ loss = None + episode_rewards = [] for episode in batch: _, _, rewards, _, _ = episode - episode_reward = rewards.sum().detach() + episode_rewards.append(rewards.sum().detach()) if loss is None: loss = self.vpg_loss(episode) @@ -180,6 +181,7 @@ def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tens if self.trainer.use_dp or self.trainer.use_ddp2: loss = loss.unsqueeze(0) + mean_episode_reward = torch.tensor(np.mean(episode_rewards)) result = pl.TrainResult(loss) result.log('loss', loss, @@ -187,8 +189,8 @@ def training_step(self, batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tens on_epoch=True, prog_bar=False, logger=True) - result.log('episode_reward', - episode_reward, + result.log('mean_episode_reward', + mean_episode_reward, on_step=True, on_epoch=True, prog_bar=True,