diff --git a/src/imitation/algorithms/pebble/__init__.py b/src/imitation/algorithms/pebble/__init__.py new file mode 100644 index 000000000..dca061476 --- /dev/null +++ b/src/imitation/algorithms/pebble/__init__.py @@ -0,0 +1 @@ +"""PEBBLE specific algorithms.""" diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py new file mode 100644 index 000000000..eba53405b --- /dev/null +++ b/src/imitation/algorithms/pebble/entropy_reward.py @@ -0,0 +1,199 @@ +"""Reward function for the PEBBLE training algorithm.""" + +import enum +from typing import Any, Callable, Optional, Tuple + +import gym +import numpy as np +import torch as th + +from imitation.policies.replay_buffer_wrapper import ( + ReplayBufferAwareRewardFn, + ReplayBufferRewardWrapper, + ReplayBufferView, +) +from imitation.rewards.reward_function import RewardFn +from imitation.rewards.reward_nets import RewardNet +from imitation.util import util + + +class InsufficientObservations(RuntimeError): + """Error signifying not enough observations for entropy calculation.""" + + pass + + +class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn): + """RewardNet wrapping entropy reward function.""" + + __call__: Callable[..., Any] # Needed to appease pytype + + def __init__( + self, + nearest_neighbor_k: int, + observation_space: gym.Space, + action_space: gym.Space, + normalize_images: bool = True, + replay_buffer_view: Optional[ReplayBufferView] = None, + ): + """Initialize the RewardNet. + + Args: + nearest_neighbor_k: Parameter for entropy computation (see + compute_state_entropy()) + observation_space: the observation space of the environment + action_space: the action space of the environment + normalize_images: whether to automatically normalize + image observations to [0, 1] (from 0 to 255). Defaults to True. + replay_buffer_view: Replay buffer view with observations to compare + against when computing entropy. If None is given, the buffer needs to + be set with on_replay_buffer_initialized() before EntropyRewardNet can + be used + """ + super().__init__(observation_space, action_space, normalize_images) + self.nearest_neighbor_k = nearest_neighbor_k + self._replay_buffer_view = replay_buffer_view + + def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper): + """Sets replay buffer. + + This method needs to be called, e.g., after unpickling. + See also __getstate__() / __setstate__(). 
+ + Args: + replay_buffer: replay buffer with history of observations + """ + assert self.observation_space == replay_buffer.observation_space + assert self.action_space == replay_buffer.action_space + self._replay_buffer_view = replay_buffer.buffer_view + + def forward( + self, + state: th.Tensor, + action: th.Tensor, + next_state: th.Tensor, + done: th.Tensor, + ) -> th.Tensor: + assert ( + self._replay_buffer_view is not None + ), "Missing replay buffer (possibly after unpickle)" + + all_observations = self._replay_buffer_view.observations + # ReplayBuffer sampling flattens the venv dimension, let's adapt to that + all_observations = all_observations.reshape( + (-1,) + self.observation_space.shape, + ) + + if all_observations.shape[0] < self.nearest_neighbor_k: + raise InsufficientObservations( + "Insufficient observations for entropy calculation", + ) + + return util.compute_state_entropy( + state, + all_observations, + self.nearest_neighbor_k, + ) + + def preprocess( + self, + state: np.ndarray, + action: np.ndarray, + next_state: np.ndarray, + done: np.ndarray, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]: + """Override default preprocessing to avoid the default one-hot encoding. + + We also know forward() only works with state, so no need to convert + other tensors. + + Args: + state: The observation input. + action: The action input. + next_state: The observation input. + done: Whether the episode has terminated. + + Returns: + Observations preprocessed by converting them to Tensor. + """ + state_th = util.safe_to_tensor(state).to(self.device) + action_th = next_state_th = done_th = th.empty(0) + return state_th, action_th, next_state_th, done_th + + def __getstate__(self): + state = self.__dict__.copy() + del state["_replay_buffer_view"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._replay_buffer_view = None + + +class PebbleRewardPhase(enum.Enum): + """States representing different behaviors for PebbleStateEntropyReward.""" + + UNSUPERVISED_EXPLORATION = enum.auto() # Entropy based reward + POLICY_AND_REWARD_LEARNING = enum.auto() # Learned reward + + +class PebbleStateEntropyReward(ReplayBufferAwareRewardFn): + """Reward function for implementation of the PEBBLE learning algorithm. + + See https://arxiv.org/abs/2106.05091 . + + The rewards returned by this function go through the three phases: + 1. Before enough samples are collected for entropy calculation, the + underlying function is returned. This shouldn't matter because + OffPolicyAlgorithms have an initialization period for `learning_starts` + timesteps. + 2. During the unsupervised exploration phase, entropy based reward is returned + 3. After unsupervised exploration phase is finished, the underlying learned + reward is returned. + + The second phase requires that a buffer with observations to compare against is + supplied with on_replay_buffer_initialized(). To transition to the last phase, + unsupervised_exploration_finish() needs to be called. + """ + + def __init__( + self, + entropy_reward_fn: RewardFn, + learned_reward_fn: RewardFn, + ): + """Builds this class. 
+ + Args: + entropy_reward_fn: The entropy-based reward function used during + unsupervised exploration + learned_reward_fn: The learned reward function used after unsupervised + exploration is finished + """ + self.entropy_reward_fn = entropy_reward_fn + self.learned_reward_fn = learned_reward_fn + self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION + + def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper): + if isinstance(self.entropy_reward_fn, ReplayBufferAwareRewardFn): + self.entropy_reward_fn.on_replay_buffer_initialized(replay_buffer) + + def unsupervised_exploration_finish(self): + assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION + self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING + + def __call__( + self, + state: np.ndarray, + action: np.ndarray, + next_state: np.ndarray, + done: np.ndarray, + ) -> np.ndarray: + if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION: + try: + return self.entropy_reward_fn(state, action, next_state, done) + except InsufficientObservations: + # not enough observations to compare to, fall back to the learned + # function; (falling back to a constant may also be ok) + return self.learned_reward_fn(state, action, next_state, done) + else: + return self.learned_reward_fn(state, action, next_state, done) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 413cd979a..fccd7958d 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -33,6 +33,7 @@ from tqdm.auto import tqdm from imitation.algorithms import base +from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward from imitation.data import rollout, types, wrappers from imitation.data.types import ( AnyPath, @@ -44,6 +45,7 @@ from imitation.policies import exploration_wrapper from imitation.regularization import regularizers from imitation.rewards import reward_function, reward_nets, reward_wrapper +from imitation.rewards.reward_function import RewardFn from imitation.util import logger as imit_logger from imitation.util import networks, util @@ -75,6 +77,45 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]: be the environment rewards, not ones from a reward model). """ # noqa: DAR202 + @property + def has_pretraining(self) -> bool: + """Indicates whether this generator has a pre-training phase. + + The value can be used, e.g., when allocating time-steps for pre-training. + + By default, True is returned if the unsupervised_pretrain() method is not + overridden, bud subclasses may choose to override this behavior. + + Returns: + True if this generator has a pre-training phase, False otherwise + """ + orig_impl = TrajectoryGenerator.unsupervised_pretrain + return type(self).unsupervised_pretrain != orig_impl + + def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None: + """Pre-train an agent before collecting comparisons. + + Override this behavior in subclasses that implement pre-training. + If not overridden, this method raises ValueError when non-zero steps are + allocated for pre-training. + + Args: + steps: number of environment steps to train for. + **kwargs: additional keyword arguments to pass on to + the training procedure. + + Raises: + ValueError: Unsupervised pre-training not implemented but non-zero + steps are allocated for pre-training. 
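For illustration, a subclass opts into pre-training simply by overriding unsupervised_pretrain(); has_pretraining then returns True because the property compares the subclass's method against TrajectoryGenerator's original implementation. A minimal sketch (class names are illustrative, not part of this change):

    from typing import Sequence

    from imitation.algorithms.preference_comparisons import TrajectoryGenerator
    from imitation.data.types import TrajectoryWithRew


    class PlainGenerator(TrajectoryGenerator):
        def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
            return []


    class PretrainingGenerator(PlainGenerator):
        def unsupervised_pretrain(self, steps: int, **kwargs) -> None:
            pass  # e.g. run RL on an entropy-based reward here


    assert not PlainGenerator().has_pretraining  # inherits the base implementation
    assert PretrainingGenerator().has_pretraining  # override is detected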
+ """ + if steps > 0: + raise ValueError( + f"{steps} timesteps allocated for unsupervised pre-training:" + " Trajectory generators without pre-training implementation should" + " not consume any timesteps (otherwise the total number of" + " timesteps executed may be misleading)", + ) + def train(self, steps: int, **kwargs: Any) -> None: """Train an agent if the trajectory generator uses one. @@ -165,7 +206,7 @@ def __init__( reward_fn.action_space, ) reward_fn = reward_fn.predict_processed - self.reward_fn = reward_fn + self.reward_fn: RewardFn = reward_fn self.exploration_frac = exploration_frac self.rng = rng @@ -316,6 +357,43 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None: self.algorithm.set_logger(self.logger) +class PebbleAgentTrainer(AgentTrainer): + """Specialization of AgentTrainer for PEBBLE training. + + Includes unsupervised pretraining with an entropy based reward function. + """ + + reward_fn: PebbleStateEntropyReward + + def __init__( + self, + *, + reward_fn: PebbleStateEntropyReward, + **kwargs, + ) -> None: + """Builds PebbleAgentTrainer. + + Args: + reward_fn: Pebble reward function + **kwargs: additional keyword arguments to pass on to the parent class + + Raises: + ValueError: Unexpected type of reward_fn given. + """ + if not isinstance(reward_fn, PebbleStateEntropyReward): + raise ValueError( + f"{self.__class__.__name__} expects " + f"{PebbleStateEntropyReward.__name__} reward function", + ) + super().__init__(reward_fn=reward_fn, **kwargs) + + def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None: + self.train(steps, **kwargs) + fn = self.reward_fn + assert isinstance(fn, PebbleStateEntropyReward) + fn.unsupervised_exploration_finish() + + def _get_trajectories( trajectories: Sequence[TrajectoryWithRew], steps: int, @@ -1495,6 +1573,7 @@ def __init__( transition_oversampling: float = 1, initial_comparison_frac: float = 0.1, initial_epoch_multiplier: float = 200.0, + unsupervised_agent_pretrain_frac: float = 0.05, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False, rng: Optional[np.random.Generator] = None, @@ -1544,6 +1623,9 @@ def __init__( initial_epoch_multiplier: before agent training begins, train the reward model for this many more epochs than usual (on fragments sampled from a random agent). + unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the + agent will be trained without preference gathering (and reward model + training) custom_logger: Where to log to; if None (default), creates a new logger. allow_variable_horizon: If False (default), algorithm will raise an exception if it detects trajectories of different length during @@ -1642,6 +1724,7 @@ def __init__( self.fragment_length = fragment_length self.initial_comparison_frac = initial_comparison_frac self.initial_epoch_multiplier = initial_epoch_multiplier + self.unsupervised_agent_pretrain_frac = unsupervised_agent_pretrain_frac self.num_iterations = num_iterations self.transition_oversampling = transition_oversampling if callable(query_schedule): @@ -1670,25 +1753,31 @@ def train( A dictionary with final metrics such as loss and accuracy of the reward model. """ - initial_comparisons = int(total_comparisons * self.initial_comparison_frac) - total_comparisons -= initial_comparisons - # Compute the number of comparisons to request at each iteration in advance. 
- vec_schedule = np.vectorize(self.query_schedule) - unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations)) - probs = unnormalized_probs / np.sum(unnormalized_probs) - shares = util.oric(probs * total_comparisons) - schedule = [initial_comparisons] + shares.tolist() - print(f"Query schedule: {schedule}") - - timesteps_per_iteration, extra_timesteps = divmod( - total_timesteps, - self.num_iterations, - ) + preference_query_schedule = self._preference_gather_schedule(total_comparisons) + self.logger.log(f"Query schedule: {preference_query_schedule}") + + ( + unsup_pretrain_timesteps, + timesteps_per_iteration, + extra_timesteps, + ) = self._compute_timesteps(total_timesteps) reward_loss = None reward_accuracy = None - for i, num_pairs in enumerate(schedule): + ################################################### + # Pre-training agent before gathering preferences # + ################################################### + if unsup_pretrain_timesteps: + with self.logger.accumulate_means("agent"): + self.logger.log( + f"Pre-training agent for {unsup_pretrain_timesteps} timesteps", + ) + self.trajectory_generator.unsupervised_pretrain( + unsup_pretrain_timesteps, + ) + + for i, num_pairs in enumerate(preference_query_schedule): ########################## # Gather new preferences # ########################## @@ -1751,3 +1840,26 @@ def train( self._iteration += 1 return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy} + + def _preference_gather_schedule(self, total_comparisons): + initial_comparisons = int(total_comparisons * self.initial_comparison_frac) + total_comparisons -= initial_comparisons + vec_schedule = np.vectorize(self.query_schedule) + unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations)) + probs = unnormalized_probs / np.sum(unnormalized_probs) + shares = util.oric(probs * total_comparisons) + schedule = [initial_comparisons] + shares.tolist() + return schedule + + def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]: + if self.trajectory_generator.has_pretraining: + unsupervised_pretrain_timesteps = int( + total_timesteps * self.unsupervised_agent_pretrain_frac, + ) + else: + unsupervised_pretrain_timesteps = 0 + timesteps_per_iteration, extra_timesteps = divmod( + total_timesteps - unsupervised_pretrain_timesteps, + self.num_iterations, + ) + return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 3101cf2c7..9d455ff15 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -76,7 +76,7 @@ class SAC1024Policy(sac_policies.SACPolicy): """Actor and value networks with two hidden layers of 1024 units respectively. This matches the implementation of SAC policies in the PEBBLE paper. 
See: - https://arxiv.org/pdf/2106.05091.pdf + https://arxiv.org/abs/2106.05091 https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index 6d0d70449..a309917c2 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -1,6 +1,6 @@ """Wrapper for reward labeling for transitions sampled from a replay buffer.""" - -from typing import Mapping, Type +import abc +from typing import Callable, Mapping, Type import numpy as np from gym import spaces @@ -23,6 +23,30 @@ def _samples_to_reward_fn_input( ) +class ReplayBufferView: + """A read-only view over valid records in a ReplayBuffer.""" + + def __init__( + self, + observations_buffer: np.ndarray, + buffer_slice_provider: Callable[[], slice], + ): + """Builds ReplayBufferView. + + Args: + observations_buffer: Array buffer holding observations + buffer_slice_provider: Function returning slice of buffer + with valid observations + """ + self._observations_buffer_view = observations_buffer.view() + self._observations_buffer_view.flags.writeable = False + self._buffer_slice_provider = buffer_slice_provider + + @property + def observations(self): + return self._observations_buffer_view[self._buffer_slice_provider()] + + class ReplayBufferRewardWrapper(ReplayBuffer): """Relabel the rewards in transitions sampled from a ReplayBuffer.""" @@ -61,6 +85,8 @@ def __init__( self.reward_fn = reward_fn _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]} super().__init__(buffer_size, observation_space, action_space, **_base_kwargs) + if isinstance(reward_fn, ReplayBufferAwareRewardFn): + reward_fn.on_replay_buffer_initialized(self) @property def pos(self) -> int: @@ -78,6 +104,13 @@ def full(self) -> bool: def full(self, full: bool): self.replay_buffer.full = full + @property + def buffer_view(self) -> ReplayBufferView: + def valid_buffer_slice(): + return slice(None) if self.full else slice(self.pos) + + return ReplayBufferView(self.replay_buffer.observations, valid_buffer_slice) + def sample(self, *args, **kwargs): samples = self.replay_buffer.sample(*args, **kwargs) rewards = self.reward_fn(**_samples_to_reward_fn_input(samples)) @@ -101,3 +134,21 @@ def _get_samples(self): "_get_samples() is intentionally not implemented." "This method should not be called.", ) + + +class ReplayBufferAwareRewardFn(RewardFn, abc.ABC): + """Abstract class for a reward function that needs access to a replay buffer.""" + + @abc.abstractmethod + def on_replay_buffer_initialized( + self, + replay_buffer: ReplayBufferRewardWrapper, + ) -> None: + """Hook method to be called when ReplayBuffer is initialized. + + Needed to propagate the ReplayBuffer to a reward function because the buffer + is created indirectly in ReplayBufferRewardWrapper. 
+ + Args: + replay_buffer: the created ReplayBuffer + """ # noqa: DAR202 diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index 2bd3759a2..d71e35211 100644 --- a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -86,10 +86,12 @@ def _maybe_add_relabel_buffer( """Use ReplayBufferRewardWrapper in rl_kwargs if relabel_reward_fn is not None.""" rl_kwargs = dict(rl_kwargs) if relabel_reward_fn: - _buffer_kwargs = dict(reward_fn=relabel_reward_fn) - _buffer_kwargs["replay_buffer_class"] = rl_kwargs.get( - "replay_buffer_class", - buffers.ReplayBuffer, + _buffer_kwargs = dict( + reward_fn=relabel_reward_fn, + replay_buffer_class=rl_kwargs.get( + "replay_buffer_class", + buffers.ReplayBuffer, + ), ) rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ba4e9483c..3a66349c5 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -1,8 +1,10 @@ """Configuration for imitation.scripts.train_preference_comparisons.""" import sacred +import stable_baselines3 as sb3 from imitation.algorithms import preference_comparisons +from imitation.policies import base from imitation.scripts.common import common, reward, rl, train train_preference_comparisons_ex = sacred.Experiment( @@ -15,7 +17,6 @@ ], ) - MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1))) ANT_SHARED_LOCALS = dict( total_timesteps=int(3e7), @@ -60,6 +61,29 @@ def train_defaults(): checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only) query_schedule = "hyperbolic" + # Whether to use the PEBBLE algorithm (https://arxiv.org/abs/2106.05091) + pebble_enabled = False + unsupervised_agent_pretrain_frac = 0.0 + + +@train_preference_comparisons_ex.named_config +def pebble(): + # fraction of total_timesteps for training before preference gathering + pebble_enabled = True + unsupervised_agent_pretrain_frac = 0.05 + pebble_nearest_neighbor_k = 5 + + rl = { + "rl_cls": sb3.SAC, + "batch_size": 256, # batch size for RL algorithm + "rl_kwargs": {"batch_size": None}, # make sure to set batch size to None + } + train = { + "policy_cls": base.SAC1024Policy, # noqa: F841 + } + + locals() # quieten flake8 + @train_preference_comparisons_ex.named_config def cartpole(): @@ -115,14 +139,23 @@ def seals_mountain_car(): common = dict(env_name="seals/MountainCar-v0") +@train_preference_comparisons_ex.named_config +def mountain_car_continuous(): + common = {"env_name": "MountainCarContinuous-v0"} + allow_variable_horizon = True + locals() # quieten flake8 + + @train_preference_comparisons_ex.named_config def fast(): # Minimize the amount of computation. Useful for test cases. total_timesteps = 50 total_comparisons = 5 initial_comparison_frac = 0.2 + unsupervised_agent_pretrain_frac = 0.2 num_iterations = 1 fragment_length = 2 reward_trainer_kwargs = { "epochs": 1, } + locals() # quieten flake8 diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 331a4797a..5e07b094c 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -3,24 +3,38 @@ Can be used as a CLI script, or the `train_preference_comparisons` function can be called directly. 
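The new `pebble` named config can be combined with an environment config when running the sacred experiment; a minimal sketch mirroring the smoke test added in tests/scripts/test_scripts.py (the config overrides are illustrative):

    from imitation.scripts.train_preference_comparisons import (
        train_preference_comparisons_ex,
    )

    # Enables SAC + unsupervised entropy pre-training on a continuous-control task.
    run = train_preference_comparisons_ex.run(
        named_configs=["pebble", "mountain_car_continuous"],
        config_updates=dict(total_timesteps=10_000, total_comparisons=50),
    )
    print(run.result)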
""" - import functools import pathlib from typing import Any, Mapping, Optional, Type, Union +import gym +import numpy as np import torch as th from sacred.observers import FileStorageObserver -from stable_baselines3.common import type_aliases +from stable_baselines3.common import base_class, type_aliases, vec_env from imitation.algorithms import preference_comparisons +from imitation.algorithms.pebble.entropy_reward import ( + EntropyRewardNet, + PebbleStateEntropyReward, +) from imitation.data import types from imitation.policies import serialize +from imitation.policies.replay_buffer_wrapper import ( + ReplayBufferAwareRewardFn, + ReplayBufferRewardWrapper, +) +from imitation.rewards import reward_function, reward_nets +from imitation.rewards.reward_function import RewardFn +from imitation.rewards.reward_nets import NormalizedRewardNet from imitation.scripts.common import common, reward from imitation.scripts.common import rl as rl_common from imitation.scripts.common import train from imitation.scripts.config.train_preference_comparisons import ( train_preference_comparisons_ex, ) +from imitation.util import logger as imit_logger +from imitation.util.networks import RunningNorm def save_model( @@ -57,6 +71,96 @@ def save_checkpoint( ) +@train_preference_comparisons_ex.capture +def make_reward_function( + reward_net: reward_nets.RewardNet, + *, + pebble_enabled: bool = False, + pebble_nearest_neighbor_k: int = 5, +): + relabel_reward_fn = functools.partial( + reward_net.predict_processed, + update_stats=False, + ) + if pebble_enabled: + relabel_reward_fn = create_pebble_reward_fn( + relabel_reward_fn, # type: ignore[assignment] + pebble_nearest_neighbor_k, + reward_net.action_space, + reward_net.observation_space, + ) + return relabel_reward_fn + + +def create_pebble_reward_fn( + relabel_reward_fn: RewardFn, + pebble_nearest_neighbor_k: int, + action_space: gym.Space, + observation_space: gym.Space, +) -> PebbleStateEntropyReward: + entropy_reward_net = EntropyRewardNet( + nearest_neighbor_k=pebble_nearest_neighbor_k, + observation_space=observation_space, + action_space=action_space, + normalize_images=False, + ) + normalized_entropy_reward_net = NormalizedRewardNet(entropy_reward_net, RunningNorm) + + class EntropyRewardFn(ReplayBufferAwareRewardFn): + """Adapter for entropy reward adding on_replay_buffer_initialized() hook.""" + + def __call__(self, *args, **kwargs) -> np.ndarray: + kwargs["update_stats"] = True + return normalized_entropy_reward_net.predict_processed(*args, **kwargs) + + def on_replay_buffer_initialized( + self, + replay_buffer: ReplayBufferRewardWrapper, + ): + entropy_reward_net.on_replay_buffer_initialized(replay_buffer) + + return PebbleStateEntropyReward( + EntropyRewardFn(), + relabel_reward_fn, + ) + + +@train_preference_comparisons_ex.capture +def make_agent_trajectory_generator( + venv: vec_env.VecEnv, + agent: base_class.BaseAlgorithm, + reward_net: reward_nets.RewardNet, + relabel_reward_fn: reward_function.RewardFn, + rng: np.random.Generator, + custom_logger: Optional[imit_logger.HierarchicalLogger], + *, + exploration_frac: float, + pebble_enabled: bool, + trajectory_generator_kwargs: Mapping[str, Any], +) -> preference_comparisons.AgentTrainer: + if pebble_enabled: + assert isinstance(relabel_reward_fn, PebbleStateEntropyReward) + return preference_comparisons.PebbleAgentTrainer( + algorithm=agent, + reward_fn=relabel_reward_fn, + venv=venv, + exploration_frac=exploration_frac, + rng=rng, + custom_logger=custom_logger, + **trajectory_generator_kwargs, 
+ ) + else: + return preference_comparisons.AgentTrainer( + algorithm=agent, + reward_fn=reward_net, + venv=venv, + exploration_frac=exploration_frac, + rng=rng, + custom_logger=custom_logger, + **trajectory_generator_kwargs, + ) + + @train_preference_comparisons_ex.main def train_preference_comparisons( total_timesteps: int, @@ -82,6 +186,7 @@ def train_preference_comparisons( allow_variable_horizon: bool, checkpoint_interval: int, query_schedule: Union[str, type_aliases.Schedule], + unsupervised_agent_pretrain_frac: float, ) -> Mapping[str, Any]: """Train a reward model using preference comparisons. @@ -141,6 +246,9 @@ def train_preference_comparisons( be allocated to each iteration. "hyperbolic" and "inverse_quadratic" apportion fewer queries to later iterations when the policy is assumed to be better and more stable. + unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the + agent will be trained without preference gathering (and reward model + training) Returns: Rollout statistics from trained policy. @@ -153,10 +261,8 @@ def train_preference_comparisons( with common.make_venv() as venv: reward_net = reward.make_reward_net(venv) - relabel_reward_fn = functools.partial( - reward_net.predict_processed, - update_stats=False, - ) + relabel_reward_fn = make_reward_function(reward_net) + if agent_path is None: agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn) else: @@ -169,21 +275,17 @@ def train_preference_comparisons( if trajectory_path is None: # Setting the logger here is not necessary (PreferenceComparisons takes care # of it automatically) but it avoids creating unnecessary loggers. - agent_trainer = preference_comparisons.AgentTrainer( - algorithm=agent, - reward_fn=reward_net, + trajectory_generator = make_agent_trajectory_generator( venv=venv, - exploration_frac=exploration_frac, + agent=agent, + reward_net=reward_net, + relabel_reward_fn=relabel_reward_fn, rng=rng, custom_logger=custom_logger, - **trajectory_generator_kwargs, ) # Stable Baselines will automatically occupy GPU 0 if it is available. # Let's use the same device as the SB3 agent for the reward model. 
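The PebbleStateEntropyReward produced above switches behaviour in two ways: it falls back to the learned reward whenever the entropy function raises InsufficientObservations (replay buffer still too small), and it switches to the learned reward permanently once unsupervised_exploration_finish() is called. A minimal sketch with stub reward functions (the stubs are illustrative):

    import numpy as np

    from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward


    def entropy_fn(state, action, next_state, done):
        return np.zeros(len(state))


    def learned_fn(state, action, next_state, done):
        return np.ones(len(state))


    reward_fn = PebbleStateEntropyReward(entropy_fn, learned_fn)
    obs = act = next_obs = done = np.zeros((4, 1))

    assert (reward_fn(obs, act, next_obs, done) == 0).all()  # exploration: entropy reward
    reward_fn.unsupervised_exploration_finish()
    assert (reward_fn(obs, act, next_obs, done) == 1).all()  # afterwards: learned reward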
- reward_net = reward_net.to(agent_trainer.algorithm.device) - trajectory_generator: preference_comparisons.TrajectoryGenerator = ( - agent_trainer - ) + reward_net = reward_net.to(trajectory_generator.algorithm.device) else: if exploration_frac > 0: raise ValueError( @@ -244,6 +346,7 @@ def train_preference_comparisons( custom_logger=custom_logger, allow_variable_horizon=allow_variable_horizon, query_schedule=query_schedule, + unsupervised_agent_pretrain_frac=unsupervised_agent_pretrain_frac, ) def save_callback(iteration_num): diff --git a/src/imitation/util/networks.py b/src/imitation/util/networks.py index c27aea2cd..e9564ca44 100644 --- a/src/imitation/util/networks.py +++ b/src/imitation/util/networks.py @@ -86,6 +86,9 @@ def forward(self, x: th.Tensor) -> th.Tensor: with th.no_grad(): self.update_stats(x) + return self.normalize(x) + + def normalize(self, x: th.Tensor) -> th.Tensor: # Note: this is different from the behavior in stable-baselines, see # https://github.com/HumanCompatibleAI/imitation/issues/442 return (x - self.running_mean) / th.sqrt(self.running_var + self.eps) @@ -126,12 +129,12 @@ def update_stats(self, batch: th.Tensor) -> None: tot_count = self.count + batch_count self.running_mean += delta * batch_count / tot_count - self.running_var *= self.count - self.running_var += batch_var * batch_count - self.running_var += th.square(delta) * self.count * batch_count / tot_count - self.running_var /= tot_count + m_a = self.running_var * self.count + m_b = batch_var * batch_count + M2 = m_a + m_b + th.square(delta) * self.count * batch_count / tot_count + self.running_var = M2 / tot_count - self.count += batch_count + self.count = tot_count class EMANorm(BaseNorm): diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index bbb7b2c37..cf38cee5a 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -359,3 +359,39 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]: return_iterable = iterable return first_element, return_iterable + + +def compute_state_entropy( + obs: th.Tensor, + all_obs: th.Tensor, + k: int, +) -> th.Tensor: + """Compute the state entropy given by KNN distance. + + Args: + obs: A batch of observations. + all_obs: The tensor of all states to compare to. + k: the number of neighbors to consider + + Returns: + A tensor containing the state entropy for `obs`. 
+ """ + assert obs.shape[1:] == all_obs.shape[1:] + batch_size = 500 + with th.no_grad(): + non_batch_dimensions = tuple(range(2, len(obs.shape) + 1)) + dists: List[th.Tensor] = [] + for idx in range(len(all_obs) // batch_size + 1): + start = idx * batch_size + end = (idx + 1) * batch_size + all_obs_batch = all_obs[start:end] + distances_tensor = th.linalg.vector_norm( + obs[:, None] - all_obs_batch[None, :], + dim=non_batch_dimensions, + ord=2, + ) + assert distances_tensor.shape == (obs.shape[0], all_obs_batch.shape[0]) + dists.append(distances_tensor) + all_dists = th.cat(dists, dim=1) + knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values + return knn_dists diff --git a/tests/algorithms/pebble/test_entropy_reward.py b/tests/algorithms/pebble/test_entropy_reward.py new file mode 100644 index 000000000..461a7dd5a --- /dev/null +++ b/tests/algorithms/pebble/test_entropy_reward.py @@ -0,0 +1,164 @@ +"""Tests for `imitation.algorithms.entropy_reward`.""" +import pickle +from unittest.mock import Mock + +import numpy as np +import pytest +import torch as th +from gym.spaces import Box +from gym.spaces.space import Space + +from imitation.algorithms.pebble.entropy_reward import ( + EntropyRewardNet, + InsufficientObservations, + PebbleStateEntropyReward, +) +from imitation.policies.replay_buffer_wrapper import ( + ReplayBufferAwareRewardFn, + ReplayBufferView, +) +from imitation.util import util + +SPACE = Box(-1, 1, shape=(1,)) +PLACEHOLDER = np.empty(SPACE.shape) + +BUFFER_SIZE = 20 +K = 4 +BATCH_SIZE = 8 +VENVS = 2 + + +def test_pebble_entropy_reward_returns_entropy_for_pretraining(): + expected_result = th.rand(BATCH_SIZE) + observations = th.rand((BATCH_SIZE,) + SPACE.shape) + entropy_fn = Mock() + entropy_fn.return_value = expected_result + learned_fn = Mock() + + reward_fn = PebbleStateEntropyReward(entropy_fn, learned_fn) + reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + + np.testing.assert_allclose(reward, expected_result) + entropy_fn.assert_called_once_with( + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + +def test_pebble_entropy_reward_returns_learned_rew_on_insufficient_observations(rng): + expected_result = th.rand(BATCH_SIZE) + observations = th.rand((BATCH_SIZE,) + SPACE.shape) + entropy_fn = Mock() + entropy_fn.side_effect = InsufficientObservations("test error") + learned_fn = Mock() + learned_fn.return_value = expected_result + + reward_fn = PebbleStateEntropyReward(entropy_fn, learned_fn) + reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + + np.testing.assert_allclose(reward, expected_result) + learned_fn.assert_called_once_with( + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + +def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training(): + expected_result = th.rand(BATCH_SIZE) + observations = th.rand((BATCH_SIZE,) + SPACE.shape) + entropy_fn = Mock() + learned_fn = Mock() + learned_fn.return_value = expected_result + + reward_fn = PebbleStateEntropyReward(entropy_fn, learned_fn) + reward_fn.unsupervised_exploration_finish() + reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + + np.testing.assert_allclose(reward, expected_result) + learned_fn.assert_called_once_with( + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + +def test_pebble_entropy_reward_propagates_on_replay_buffer_initialized(): + replay_buffer = replay_buffer_mock(np.empty((BUFFER_SIZE, VENVS) + SPACE.shape)) + entropy_fn = 
Mock(spec=ReplayBufferAwareRewardFn) + learned_fn = Mock() + + reward_fn = PebbleStateEntropyReward(entropy_fn, learned_fn) + reward_fn.on_replay_buffer_initialized(replay_buffer) + + entropy_fn.on_replay_buffer_initialized.assert_called_once_with(replay_buffer) + + +def test_entropy_reward_net_returns_entropy_for_pretraining(rng): + observations = th.rand((BATCH_SIZE, *SPACE.shape)) + all_observations = rng.random((BUFFER_SIZE, VENVS) + SPACE.shape) + reward_net = EntropyRewardNet(K, SPACE, SPACE) + reward_net.on_replay_buffer_initialized(replay_buffer_mock(all_observations)) + + # Act + reward = reward_net.predict_processed( + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + # Assert + expected = util.compute_state_entropy( + observations, + all_observations.reshape(-1, *SPACE.shape), + K, + ) + np.testing.assert_allclose(reward, expected, rtol=0.005, atol=0.005) + + +def test_entropy_reward_net_raises_on_insufficient_observations(rng): + observations = th.rand((BATCH_SIZE, *SPACE.shape)) + all_observations = rng.random((K - 1, 1) + SPACE.shape) + reward_net = EntropyRewardNet(K, SPACE, SPACE) + reward_net.on_replay_buffer_initialized(replay_buffer_mock(all_observations)) + + # Act + with pytest.raises(InsufficientObservations): + reward_net.predict_processed( + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + +def test_entropy_reward_net_can_pickle(rng): + all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape)) + replay_buffer = replay_buffer_mock(all_observations) + reward_net = EntropyRewardNet(K, SPACE, SPACE) + reward_net.on_replay_buffer_initialized(replay_buffer) + + # Act + pickled = pickle.dumps(reward_net) + reward_fn_deserialized = pickle.loads(pickled) + reward_fn_deserialized.on_replay_buffer_initialized(replay_buffer) + + # Assert + obs = th.rand(VENVS, *SPACE.shape) + expected_result = reward_net(obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + actual_result = reward_fn_deserialized(obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + np.testing.assert_allclose(actual_result, expected_result) + + +def replay_buffer_mock(all_observations: np.ndarray, obs_space: Space = SPACE) -> Mock: + buffer_view = ReplayBufferView(all_observations, lambda: slice(None)) + mock = Mock() + mock.buffer_view = buffer_view + mock.observation_space = obs_space + mock.action_space = SPACE + return mock diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 12727c1c9..c66dcc157 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -3,6 +3,7 @@ import math import re from typing import Any, Sequence +from unittest.mock import Mock import gym import numpy as np @@ -17,10 +18,17 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons +from imitation.algorithms.preference_comparisons import ( + PebbleAgentTrainer, + TrajectoryGenerator, +) from imitation.data import types from imitation.data.types import TrajectoryWithRew +from imitation.policies.replay_buffer_wrapper import ReplayBufferView from imitation.regularization import regularizers, updaters from imitation.rewards import reward_nets +from imitation.rewards.reward_function import RewardFn +from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn from imitation.util import networks, util UNCERTAINTY_ON = ["logit", "probability", "label"] @@ -72,6 +80,32 @@ def agent_trainer(agent, reward_net, venv, 
rng): return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng) +@pytest.fixture +def replay_buffer(rng): + return ReplayBufferView(rng.random((10, 8, 4)), lambda: slice(None)) + + +@pytest.fixture +def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer): + replay_buffer_mock = Mock() + replay_buffer_mock.buffer_view = replay_buffer + replay_buffer_mock.observation_space = venv.observation_space + replay_buffer_mock.action_space = venv.action_space + reward_fn = create_pebble_reward_fn( + reward_net.predict_processed, + 5, + venv.action_space, + venv.observation_space, + ) + reward_fn.on_replay_buffer_initialized(replay_buffer_mock) + return preference_comparisons.PebbleAgentTrainer( + algorithm=agent, + reward_fn=reward_fn, + venv=venv, + rng=rng, + ) + + def assert_info_arrs_equal(arr1, arr2): # pragma: no cover def check_possibly_nested_dicts_equal(dict1, dict2): for key, val1 in dict1.items(): @@ -293,14 +327,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): "schedule", ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], ) +@pytest.mark.parametrize("agent_fixture", ["agent_trainer", "pebble_agent_trainer"]) def test_trainer_no_crash( - agent_trainer, + request, + agent_fixture, reward_net, random_fragmenter, custom_logger, schedule, rng, ): + agent_trainer = request.getfixturevalue(agent_fixture) main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, reward_net, @@ -1088,3 +1125,28 @@ def test_that_trainer_improves( ) assert np.mean(trained_agent_rewards) > np.mean(novice_agent_rewards) + + +def test_trajectory_generator_raises_on_pretrain_if_not_implemented(): + class TrajectoryGeneratorTestImpl(TrajectoryGenerator): + def sample(self, steps: int) -> Sequence[TrajectoryWithRew]: + return [] + + generator = TrajectoryGeneratorTestImpl() + assert generator.has_pretraining is False + with pytest.raises(ValueError, match="should not consume any timesteps"): + generator.unsupervised_pretrain(1) + + generator.sample(1) # just to make coverage happy + + +def test_pebble_agent_trainer_expects_pebble_reward(agent, venv, rng): + reward_fn: RewardFn = lambda state, action, next, done: state + + with pytest.raises(ValueError, match="PebbleStateEntropyReward"): + PebbleAgentTrainer( + algorithm=agent, + reward_fn=reward_fn, # type: ignore[call-arg] + venv=venv, + rng=rng, + ) diff --git a/tests/policies/test_replay_buffer_wrapper.py b/tests/policies/test_replay_buffer_wrapper.py index 40fc6eac5..7b92e64ba 100644 --- a/tests/policies/test_replay_buffer_wrapper.py +++ b/tests/policies/test_replay_buffer_wrapper.py @@ -2,16 +2,23 @@ import os.path as osp from typing import Type +from unittest.mock import Mock import numpy as np import pytest import stable_baselines3 as sb3 import torch as th +from gym import spaces from stable_baselines3.common import buffers, off_policy_algorithm, policies +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape from stable_baselines3.common.save_util import load_from_pkl -from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper +from imitation.policies.replay_buffer_wrapper import ( + ReplayBufferAwareRewardFn, + ReplayBufferRewardWrapper, +) from imitation.util import util @@ -112,3 +119,53 @@ def test_wrapper_class(tmpdir, rng): # raise error for _get_samples() with 
pytest.raises(NotImplementedError, match=r".*_get_samples.*"): replay_buffer_wrapper._get_samples() + + +def test_replay_buffer_view_provides_buffered_observations(): + space = spaces.Box(np.array([0]), np.array([5])) + n_envs = 2 + buffer_size = 10 + action = np.empty((n_envs, get_action_dim(space))) + + obs_shape = get_obs_shape(space) + wrapper = ReplayBufferRewardWrapper( + buffer_size, + space, + space, + replay_buffer_class=ReplayBuffer, + reward_fn=Mock(), + n_envs=n_envs, + handle_timeout_termination=False, + ) + view = wrapper.buffer_view + + # initially empty + assert len(view.observations) == 0 + + # after adding observation + obs1 = np.random.random((n_envs, *obs_shape)) + wrapper.add(obs1, obs1, action, np.empty(n_envs), np.empty(n_envs), []) + np.testing.assert_allclose(view.observations, np.array([obs1])) + + # after filling buffer + observations = np.random.random((buffer_size // n_envs, n_envs, *obs_shape)) + for obs in observations: + wrapper.add(obs, obs, action, np.empty(n_envs), np.empty(n_envs), []) + + # ReplayBuffer internally uses a circular buffer + expected = np.roll(observations, 1, axis=0) + np.testing.assert_allclose(view.observations, expected) + + +def test_replay_buffer_reward_wrapper_calls_reward_initialization_callback(): + reward_fn = Mock(spec=ReplayBufferAwareRewardFn) + buffer = ReplayBufferRewardWrapper( + 10, + spaces.Discrete(2), + spaces.Discrete(2), + replay_buffer_class=ReplayBuffer, + reward_fn=reward_fn, + n_envs=2, + handle_timeout_termination=False, + ) + assert reward_fn.on_replay_buffer_initialized.call_args.args[0] is buffer diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 226b6b3c2..1f8a0d23d 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -254,6 +254,20 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs) assert isinstance(run.result, dict) +def test_train_preference_comparisons_pebble_config(tmpdir): + config_updates = dict(common=dict(log_root=tmpdir)) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + # make sure rl.sac named_config is called after rl.fast to overwrite + # rl_kwargs.batch_size to None + named_configs=ALGO_FAST_CONFIGS["preference_comparison"] + + ["pebble", "mountain_car_continuous"], + config_updates=config_updates, + ) + assert run.config["rl"]["rl_cls"] is stable_baselines3.SAC + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + def test_train_dagger_main(tmpdir): with pytest.warns(None) as record: run = train_imitation.train_imitation_ex.run( diff --git a/tests/scripts/test_train_preference_comparisons.py b/tests/scripts/test_train_preference_comparisons.py new file mode 100644 index 000000000..cf794fecf --- /dev/null +++ b/tests/scripts/test_train_preference_comparisons.py @@ -0,0 +1,69 @@ +"""Tests train_preferences_comparisons helper methods.""" + +from unittest.mock import Mock, patch + +import numpy as np +import torch as th +from gym import Space +from gym.spaces import Box + +from imitation.policies.replay_buffer_wrapper import ReplayBufferView +from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn + +K = 4 +SPACE = Box(-1, 1, shape=(1,)) +BUFFER_SIZE = 20 +VENVS = 2 +PLACEHOLDER = np.empty(SPACE.shape) + + +def test_creates_normalized_entropy_pebble_reward(): + with patch("imitation.util.util.compute_state_entropy") as m: + # mock entropy computation so that we can test + # only stats collection in this test + m.side_effect = lambda 
obs, all_obs, k: obs + + reward_fn = create_pebble_reward_fn(reward_fn_stub, K, SPACE, SPACE) + + all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape)) + reward_fn.on_replay_buffer_initialized(replay_buffer_mock(all_observations)) + + dim = 8 + shift = 3 + scale = 2 + + # Act + for _ in range(1000): + state = th.randn(dim) * scale + shift + reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + + normalized_reward = reward_fn( + np.zeros(dim), + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, + ) + + # Assert + np.testing.assert_allclose( + normalized_reward, + np.repeat(-shift / scale, dim), + rtol=0.05, + atol=0.05, + ) + + # Just to make coverage happy: + reward_fn_stub(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) + + +def reward_fn_stub(state, action, next_state, done): + return state + + +def replay_buffer_mock(all_observations: np.ndarray, obs_space: Space = SPACE) -> Mock: + buffer_view = ReplayBufferView(all_observations, lambda: slice(None)) + mock = Mock() + mock.buffer_view = buffer_view + mock.observation_space = obs_space + mock.action_space = SPACE + return mock diff --git a/tests/util/test_util.py b/tests/util/test_util.py index ce663d8e0..28678dc8b 100644 --- a/tests/util/test_util.py +++ b/tests/util/test_util.py @@ -118,3 +118,29 @@ def test_tensor_iter_norm(): assert np.allclose(norm_1, 14.0) with pytest.raises(ValueError): util.tensor_iter_norm(tensor_list, ord=0.0) + + +def test_compute_state_entropy_1d(): + all_obs = th.arange(10, dtype=th.float).unsqueeze(1) + obs = all_obs[4:6] + np.testing.assert_allclose(util.compute_state_entropy(obs, all_obs, k=1), 1) + np.testing.assert_allclose(util.compute_state_entropy(obs, all_obs, k=2), 1) + np.testing.assert_allclose(util.compute_state_entropy(obs, all_obs, k=3), 2) + np.testing.assert_allclose(util.compute_state_entropy(obs, all_obs, k=4), 2) + np.testing.assert_allclose(util.compute_state_entropy(obs, all_obs, k=5), 3) + + +def test_compute_state_entropy_2d(): + all_obs_x = th.arange(10, dtype=th.float) + all_obs_y = th.arange(0, 100, step=10, dtype=th.float) + all_obs = th.stack((all_obs_x, all_obs_y), dim=1) + + obs = all_obs[4:6] + np.testing.assert_allclose( + util.compute_state_entropy(obs, all_obs, k=1), + np.sqrt(10**2 + 1**2), + ) + np.testing.assert_allclose( + util.compute_state_entropy(obs, all_obs, k=3), + np.sqrt(20**2 + 2**2), + )
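To make the expected values above concrete: for the query state 4, the sorted distances to the ten buffer states are [0, 1, 1, 2, 2, 3, 3, 4, 4, 5], and the kthvalue(k + 1) call inside compute_state_entropy skips the zero self-distance, giving 1 for k in {1, 2}, 2 for k in {3, 4}, and 3 for k = 5. A standalone check of that reasoning (a sketch, not part of the test suite):

    import torch as th

    from imitation.util import util

    all_obs = th.arange(10, dtype=th.float).unsqueeze(1)  # buffer states 0..9
    obs = all_obs[4:5]  # the single query state 4

    sorted_dists = th.sort((all_obs - obs).abs().flatten()).values
    # kthvalue(k + 1) skips the zero distance of the state to itself.
    for k in range(1, 6):
        assert util.compute_state_entropy(obs, all_obs, k=k) == sorted_dists[k]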