diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index ba1f550df..278530ac3 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -101,7 +101,7 @@ class FeedForward32Policy(policies.ActorCriticPolicy): def __init__(self, *args, **kwargs): """Builds FeedForward32Policy; arguments passed to `ActorCriticPolicy`.""" - super().__init__(*args, **kwargs, net_arch=[32, 32]) + super().__init__(*args, **kwargs) class SAC1024Policy(sac_policies.SACPolicy): @@ -117,7 +117,7 @@ class SAC1024Policy(sac_policies.SACPolicy): def __init__(self, *args, **kwargs): """Builds SAC1024Policy; arguments passed to `SACPolicy`.""" - super().__init__(*args, **kwargs, net_arch=[1024, 1024]) + super().__init__(*args, **kwargs) class NormalizeFeaturesExtractor(torch_layers.FlattenExtractor):