Skip to content

Commit 96517b1

Browse files
authored
[Fix] refactor evaluator, env, and extensions (#121)
* update evaluator and remove vlnpe part eval and env * add env * update env utils; resolve import evaluator issue in server * update env utils; resolve import evaluator issue in server * isolate dependency in env * fix default; fix test
1 parent 865fd98 commit 96517b1

File tree

76 files changed

+1191
-1000
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1191
-1000
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ data
132132
kujiale_data
133133
interiornav_data
134134
images/
135-
internutopia
136-
internutopia_extension
135+
# internutopia
136+
# internutopia_extension
137137
*.pyc
138138

139139
pre-commit*

internnav/agent/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def register(cls, agent_type: str):
2222
"""
2323

2424
def decorator(agent_class):
25+
if agent_type in cls.agents:
26+
raise ValueError(f"Agent {agent_type} already registered.")
2527
cls.agents[agent_type] = agent_class
2628

2729
return decorator

internnav/agent/cma_agent.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
from gym import spaces
66

77
from internnav.agent.base import Agent
8+
from internnav.agent.utils.common import batch_obs, set_seed_model
89
from internnav.configs.agent import AgentCfg
910
from internnav.configs.model.base_encoders import ModelCfg
10-
from internnav.evaluator.utils.common import set_seed_model
11-
from internnav.evaluator.utils.models import batch_obs
1211
from internnav.model import get_config, get_policy
1312

1413

@@ -22,6 +21,7 @@ class CmaAgent(Agent):
2221
)
2322

2423
def __init__(self, agent_config: AgentCfg):
24+
2525
super().__init__(agent_config)
2626
self._model_settings = ModelCfg(**agent_config.model_settings)
2727
model_settings = self._model_settings
@@ -119,13 +119,20 @@ def inference(self, obs):
119119
dtype=torch.bool,
120120
)
121121
end = time.time()
122-
print(f'CmaAgent step time: {round(end-start,4)}s')
122+
print(f'CmaAgent step time: {round(end-start, 4)}s')
123123
return actions.cpu().numpy().tolist()
124124

125125
def step(self, obs):
126126
print('CmaPolicyAgent step')
127127
start = time.time()
128128
action = self.inference(obs)
129129
end = time.time()
130-
print(f'Time: {round(end-start,4)}s')
131-
return action
130+
print(f'Time: {round(end-start, 4)}s')
131+
132+
# convert from [[x],[y]] to [{'action': [x],'ideal_flag':True}, {'action': [y],'ideal_flag':True}]
133+
actions = []
134+
for a in action:
135+
if not isinstance(a, list):
136+
a = [a]
137+
actions.append({'action': a, 'ideal_flag': True})
138+
return actions

internnav/agent/rdp_agent.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
import random
21
import time
32

43
import numpy as np
54
import torch
65
from gym import spaces
76

87
from internnav.agent.base import Agent
8+
from internnav.agent.utils.common import batch_obs, set_seed_model
99
from internnav.configs.agent import AgentCfg
1010
from internnav.configs.model.base_encoders import ModelCfg
11-
from internnav.evaluator.utils.models import batch_obs
1211
from internnav.model import get_config, get_policy
1312
from internnav.model.basemodel.LongCLIP.model import longclip
1413
from internnav.model.basemodel.rdp.utils import (
@@ -25,17 +24,9 @@
2524
extract_image_features,
2625
extract_instruction_tokens,
2726
)
28-
from internnav.utils import common_log_util
2927
from internnav.utils.common_log_util import common_logger as log
3028

3129

32-
def set_random_seed(seed):
33-
random.seed(seed)
34-
np.random.seed(seed)
35-
torch.manual_seed(seed)
36-
torch.cuda.manual_seed_all(seed)
37-
38-
3930
@Agent.register('rdp')
4031
class RdpAgent(Agent):
4132
observation_space = spaces.Box(
@@ -47,7 +38,7 @@ class RdpAgent(Agent):
4738

4839
def __init__(self, config: AgentCfg):
4940
super().__init__(config)
50-
set_random_seed(0)
41+
set_seed_model(0)
5142
self._model_settings = self.config.model_settings
5243
self._model_settings = ModelCfg(**self._model_settings)
5344
env_num = getattr(self._model_settings, 'env_num', 1)
@@ -348,5 +339,12 @@ def step(self, obs):
348339
start = time.time()
349340
action = self.inference(obs)
350341
end = time.time()
351-
print(f'总时间: {round(end-start,4)}s')
352-
return action
342+
print(f'总时间: {round(end-start, 4)}s')
343+
344+
# convert from [[a1],[a2]] to [{'action': [a1],'ideal_flag':True}, {'action': [a2],'ideal_flag':True}]
345+
actions = []
346+
for a in action:
347+
if not isinstance(a, list):
348+
a = [a]
349+
actions.append({'action': a, 'ideal_flag': True})
350+
return actions
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
11
from collections import defaultdict
22
from typing import DefaultDict, List, Optional
33

4+
import numpy as np
45
import torch
56

67
from .tensor_dict import TensorDict
78

89

10+
def set_seed_model(seed):
11+
import random
12+
13+
import torch
14+
15+
random.seed(seed)
16+
np.random.seed(seed)
17+
torch.manual_seed(seed)
18+
torch.cuda.manual_seed(seed)
19+
torch.backends.cudnn.benchmark = False
20+
torch.backends.cudnn.deterministic = False
21+
22+
923
def batch_obs(
1024
observations,
1125
device: Optional[torch.device] = None,

internnav/configs/evaluator/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ class EvalDatasetCfg(BaseModel):
5656

5757

5858
class EvalCfg(BaseModel):
59+
eval_type: Optional[str] = None
60+
eval_settings: Optional[Dict[str, Any]] = {}
5961
agent: Optional[AgentCfg] = None
6062
env: EnvCfg
6163
task: TaskCfg
6264
dataset: EvalDatasetCfg
63-
eval_settings: Optional[Dict[str, Any]] = {}
6465

6566

6667
__all__ = [

internnav/dist.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

internnav/env/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from internnav.env.base import Env
2-
from internnav.env.vln_pe_env import VlnPeEnv
3-
from internnav.env.vln_multi_env import VlnMultiEnv
2+
from internnav.env.internutopia_env import InternutopiaEnv
43

5-
__all__ = ['Env', 'VlnPeEnv', 'VlnMultiEnv']
4+
__all__ = ['Env', 'InternutopiaEnv']

internnav/env/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def register(cls, env_type: str):
3939
"""
4040

4141
def decorator(env_class):
42+
if env_type in cls.envs:
43+
raise ValueError(f"Env {env_type} already registered.")
4244
cls.envs[env_type] = env_class
4345

4446
return decorator

0 commit comments

Comments
 (0)