32 changes: 32 additions & 0 deletions requirements.txt
@@ -0,0 +1,32 @@
glfw>=1.4.0
numpy>=1.11
Cython>=0.27.2
imageio>=2.1.2
cffi>=1.10
imagehash>=3.4
ipdb
Pillow>=4.0.0
pycparser>=2.17.0
pytest>=3.0.5
pytest-instafail==0.3.0
scipy>=0.18.0
sphinx
sphinx_rtd_theme
numpydoc
cloudpickle==1.3.0
cached-property==1.3.1
gym==0.17.3
gitpython==2.1.7
gtimer==1.0.0b5
# awscli==1.11.179
boto3==1.4.8
# ray==0.2.2
path.py==10.3.1
torch==1.6.0
joblib==0.9.4
opencv-python==4.3.0.36
torchvision==0.2.0
sk-video==1.1.10
# git+https://github.com/vitchyr/multiworld.git
moviepy
# comet_ml
10 changes: 10 additions & 0 deletions rlkit/core/logging.py
@@ -90,11 +90,18 @@ def __init__(self):

self._log_tabular_only = False
self._header_printed = False
self._comet_log = None
self.table_printer = TerminalTablePrinter()

def reset(self):
self.__init__()

def set_comet_logger(self, log):
self._comet_log = log

def get_comet_logger(self):
return self._comet_log

def _add_output(self, file_name, arr, fds, mode='a'):
if file_name not in arr:
mkdir_p(os.path.dirname(file_name))
@@ -173,6 +180,9 @@ def log(self, s, with_prefix=True, with_timestamp=True):

def record_tabular(self, key, val):
self._tabular.append((self._tabular_prefix_str + str(key), str(val)))
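# If a Comet experiment (or any logger exposing log_metrics) has been attached
# via set_comet_logger, mirror the tabular record there as well.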
if (self._comet_log is not None):
self._comet_log.log_metrics({str(self._tabular_prefix_str) + str(key): val})
# logger.set_step(step=settings["round"])

def record_dict(self, d, prefix=None):
if prefix is not None:
154 changes: 152 additions & 2 deletions rlkit/data_management/env_replay_buffer.py
@@ -10,7 +10,8 @@ def __init__(
self,
max_replay_buffer_size,
env,
env_info_sizes=None
env_info_sizes=None,
dtype="float32"
):
"""
:param max_replay_buffer_size:
@@ -30,7 +31,8 @@ def __init__(
max_replay_buffer_size=max_replay_buffer_size,
observation_dim=get_dim(self._ob_space),
action_dim=get_dim(self._action_space),
env_info_sizes=env_info_sizes
env_info_sizes=env_info_sizes,
dtype=dtype
)

def add_sample(self, observation, action, reward, terminal,
@@ -48,3 +50,151 @@ def add_sample(self, observation, action, reward, terminal,
terminal=terminal,
**kwargs
)


class PPOEnvReplayBuffer(EnvReplayBuffer):
def __init__(
self,
max_replay_buffer_size,
env,
discount_factor,
value_f,
use_gae=False,
gae_discount=0.95,
**kwargs
):
super().__init__(max_replay_buffer_size, env, **kwargs)
self._returns = np.zeros((max_replay_buffer_size, 1))
self.current_trajectory_rewards = np.zeros((max_replay_buffer_size, 1))
self._max_replay_buffer_size = max_replay_buffer_size
self.discount_factor = discount_factor
self.value_f = value_f
self.use_gae = use_gae
self.gae_discount = gae_discount
self._bottom = 0
self._values = np.zeros((max_replay_buffer_size, 1))
self._advs = np.zeros((max_replay_buffer_size, 1))
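# _returns holds the discounted reward-to-go, _values the value_f predictions,
# _advs the GAE advantages, and _bottom marks where the segment added since the
# last terminate_episode() call starts in the buffer.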

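# Illustration of the helper below: with rewards = [1., 1., 1.] and
# discount_factor = 0.9, the reversed lfilter pass computes y[t] = r[t] + 0.9 * y[t+1],
# i.e. [2.71, 1.9, 1.0].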
def discounted_rewards(self, rewards, discount_factor):
"""
Computes discounted sums along the 0th dimension of `rewards`.
inputs
------
rewards: ndarray
discount_factor: float
outputs
-------
y: ndarray with the same shape as `rewards`, satisfying
y[t] = rewards[t] + discount_factor * rewards[t+1] + discount_factor^2 * rewards[t+2] + ... + discount_factor^k * rewards[t+k],
where k = len(rewards) - t - 1
"""
from scipy import signal
assert rewards.ndim >= 1
return signal.lfilter([1], [1, -discount_factor], rewards[::-1], axis=0)[::-1]

def terminate_episode(self):
returns = []
observations = self._observations[self._bottom:self._top]
self._values[self._bottom:self._top] = ptu.get_numpy(self.value_f(ptu.from_numpy(observations)))

# b1 = np.append(self._values[self._bottom:self._top], 0)
### The proper way to terminate the episode: bootstrap the final step with the last value estimate unless the trajectory actually terminated
b1 = np.append(self._values[self._bottom:self._top], 0 if self._terminals[self._top-1] else self._values[self._top-1])
# b1 = np.append(self._values[self._bottom:self._top], self._values[self._top-1])
b1 = np.reshape(b1, (-1,1))
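# GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); the advantage is the
# (gamma * lambda)-discounted sum of these TD residuals, with lambda = self.gae_discount.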
deltas = self._rewards[self._bottom:self._top] + self.discount_factor*b1[1:] - b1[:-1]
self._advs[self._bottom:self._top] = self.discounted_rewards(deltas, self.discount_factor * self.gae_discount)
self._returns[self._bottom:self._top] = self.discounted_rewards(self._rewards[self._bottom:self._top], self.discount_factor)

self._bottom = self._top

def add_sample(self, observation, action, reward, terminal,
next_observation, env_info=None, agent_info=None):
if self._top == self._max_replay_buffer_size:
raise EnvironmentError('Replay Buffer Overflow, please reduce the number of samples added!')

# This could catch onehot vs. integer representation differences.
assert(self._actions.shape[-1] == action.size)
self._observations[self._top] = observation
self._actions[self._top] = action
self._rewards[self._top] = reward
self._terminals[self._top] = terminal
self._next_obs[self._top] = next_observation

for key in self._env_info_keys:
self._env_infos[key][self._top] = env_info[key]

for key in self._agent_info_keys:
self._agent_infos[key][self._top] = agent_info[key]

self._advance()

def add_paths(self, paths):
log.trace("Adding {} new paths. First path length: {}".format(len(paths), paths[0]['actions'].shape[0]))
for path in paths:
self.add_path(path)
# process samples after adding paths
self.process_samples(self.value_f)

def process_samples(self, value_f):
# Compute value for all states
pass
# self._advs[:] = self._returns - self._values

# Center adv
# advs = self._advs[:self._top]
# self._advs[:self._top] = (advs - advs.mean()) / (advs.std() + 1e-5)

def random_batch(self, batch_size):
indices = np.random.randint(0, self._size, batch_size)
batch = dict(
observations=self._observations[indices],
actions=self._actions[indices],
rewards=self._rewards[indices],
terminals=self._terminals[indices],
next_observations=self._next_obs[indices],
returns=self._returns[indices],
advs=self._advs[indices],
vf_preds=self._values[indices]
)
for key in self._env_info_keys:
assert key not in batch.keys()
batch[key] = self._env_infos[key][indices]
for key in self._agent_info_keys:
assert key not in batch.keys()
batch[key] = self._agent_infos[key][indices]
return batch

def all_batch_windows(self, window_len, skip=1, return_env_info=False):
# Returns arrays of shape (num_windows, window_len, dim)

start_indices = np.arange(0, self._size - window_len, skip)
terminal_sums = [self._terminals[i0:i0+window_len].sum() for i0 in start_indices]

# NB: keep only windows that lie entirely in the filled region and contain no terminal.
# The first condition is always True for start_indices as constructed above.
valid_start_mask = np.logical_and(start_indices + window_len < self._size, np.equal(terminal_sums, 0))
valid_start_indices = start_indices[valid_start_mask]

batch = dict(
observations=np.stack([self._observations[i:i+window_len] for i in valid_start_indices]),
actions=np.stack([self._actions[i:i+window_len] for i in valid_start_indices]),
rewards=np.stack([self._rewards[i:i+window_len] for i in valid_start_indices]),
terminals=np.stack([self._terminals[i:i+window_len] for i in valid_start_indices]),
buffer_idx=valid_start_indices,
)
if return_env_info:
env_info_batch = {}
for k, v in self._env_infos.items():
env_info_batch[k] = np.stack([v[i:i+window_len] for i in valid_start_indices])
batch.update(env_info_batch)
return batch

def relabel_rewards(self, rewards):
# Ensure the updated rewards match the size of all of the data currently in the buffer.
assert(rewards.shape == (self._size,))
# Assumes the rewards correspond to the most recently added data in the buffer.
# This assumption should hold, because otherwise self._size would have to be
# larger than rewards.shape[0].
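# Assuming the standard SimpleReplayBuffer._advance (the overflow check in
# add_sample means the buffer never wraps), self._top == self._size here, so
# this slice covers the whole filled region [0:self._size].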
self._rewards[self._top - self._size:self._size] = rewards[:, None]

5 changes: 3 additions & 2 deletions rlkit/data_management/simple_replay_buffer.py
@@ -13,15 +13,16 @@ def __init__(
observation_dim,
action_dim,
env_info_sizes,
dtype="float32"
):
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
self._observations = np.zeros((max_replay_buffer_size, observation_dim), dtype=dtype)
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim), dtype=dtype)
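# np.zeros defaults to float64; passing e.g. dtype="float32" halves the memory
# footprint of the observation buffers, which matters for large or image-based data.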
self._actions = np.zeros((max_replay_buffer_size, action_dim))
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
Empty file added rlkit/envs/__init__.py
Empty file.
1 change: 1 addition & 0 deletions rlkit/samplers/data_collector/path_collector.py
@@ -44,6 +44,7 @@ def collect_new_paths(
self._policy,
max_path_length=max_path_length_this_loop,
)
# print ("path: ", path["observations"])
path_len = len(path['actions'])
if (
path_len != max_path_length
8 changes: 8 additions & 0 deletions rlkit/torch/core.py
@@ -1,8 +1,16 @@
import abc
import numpy as np
import torch
from torch import nn as nn

from rlkit.torch import pytorch_util as ptu

class PyTorchModule(nn.Module, metaclass=abc.ABCMeta):
"""
Keeping wrapper around to be a bit more future-proof.
"""
pass


def eval_np(module, *args, **kwargs):
"""
71 changes: 71 additions & 0 deletions rlkit/torch/policy_gradient/categorical_mlp_policy.py
@@ -0,0 +1,71 @@
import torch
import numpy as np
import pdb

from rlkit.torch.core import PyTorchModule
from rlkit.policies.base import Policy
import rlkit.torch.pytorch_util as ptu
import torch.distributions as tdist

class CategoricalMlpPolicy(Policy, PyTorchModule):
def __init__(self, action_network, return_onehot=True):
super().__init__()
self.action_network = action_network
self.return_onehot = return_onehot
if not return_onehot:
raise NotImplementedError("return_onehot=False has not been validated with CartPole")
self.action_dim = self.action_network._modules['last_fc'].out_features

def log_prob(self, one_hot_actions, probs, none):
"""

:param one_hot_actions: (B, A) one-hot actions
:param probs: (B, A) per-action probabilities
:param none: unused placeholder matching the (probs, None) tuple returned by forward()
:returns: (B,) log-probabilities of the selected actions
:rtype: torch.Tensor

"""
assert(probs.shape[-1] == self.action_dim)
assert(one_hot_actions.shape[-1] == self.action_dim)
# Replay buffer stores discrete actions as onehots
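# Index each row's probability at the taken action and take the log; this is
# equivalent to (one_hot_actions * torch.log(probs)).sum(-1).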
return torch.log(probs[torch.arange(one_hot_actions.shape[0]), one_hot_actions.argmax(1)])

def get_action(self, observation, argmax=False):
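# forward() returns (probs, None); sample an action index (or take the argmax),
# then convert it to a one-hot so it matches what the replay buffer stores.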
action_dist = self.forward(ptu.from_numpy(observation))
action_idx = self.rsample(*action_dist)
if argmax: action_idx[0, 0] = torch.argmax(action_dist[0])
action_onehot = ptu.zeros(action_dist[0].shape, dtype=torch.int64)
action_onehot[0, action_idx[0, 0]] = 1
action_log_prob = self.log_prob(action_onehot, *action_dist)
agent_info = dict(action_log_prob=ptu.get_numpy(action_log_prob), action_dist=ptu.get_numpy(action_dist[0]))
if self.return_onehot:
return ptu.get_numpy(action_onehot).flatten().tolist(), agent_info
else:
return ptu.get_numpy(action_idx).ravel().item(), agent_info

def entropy(self, probs, none):
return - (probs * torch.log(probs)).sum(-1)

def rsample(self, probs, none):
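# Note: despite the name, Categorical sampling is not reparameterized; this
# returns an ordinary (non-differentiable) sample of shape (1, batch).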
s = tdist.Categorical(probs, validate_args=True).sample((1,))
return s

def forward(self, input):
if len(input.shape) == 1:
action_probs = self.action_network(input.view(1, -1))
else:
action_probs = self.action_network(input)
return (action_probs, None)

def kl(self, source_probs, dest_probs):
source_log_probs = torch.log(source_probs)
dest_log_probs = torch.log(dest_probs)
assert(source_probs.shape[-1] == self.action_dim)
assert(dest_probs.shape[-1] == self.action_dim)

# These must be true for discrete action spaces.
assert(0 <= source_probs.min() <= source_probs.max() <= 1)
assert(0 <= dest_probs.min() <= dest_probs.max() <= 1)
kl = (source_probs * (source_log_probs - dest_log_probs)).sum(-1)
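# KL between two categorical distributions is non-negative, so only a small
# numerical tolerance is allowed below.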
assert(ptu.get_numpy(kl.min()) >= -1e-5)
return kl