Skip to content

Commit a0f3c82

Browse files
authored
[FIX] solve import issue, isolate dependency and requirements (#135)
* update import issue; isolate env and evaluator; isolate agent and model * fix habitat import; separate requirements * fix req * fix req; add simple agent * add depth image * fix requires * model READY, test agent runnable * update requires * key fix to resolve parse issue, remove run parse when import server * update habitat scripts * requires ready
1 parent efa856b commit a0f3c82

File tree

34 files changed

+361
-138
lines changed

34 files changed

+361
-138
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,5 @@ logs/
150150
/results/
151151
checkpoints
152152
internnav/model/basemodel/LongCLIP/
153+
.gradio/
154+
result/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ repos:
3737
- id: trailing-whitespace
3838
- id: check-yaml
3939
- id: end-of-file-fixer
40-
- id: requirements-txt-fixer
40+
# - id: requirements-txt-fixer
4141
- id: check-merge-conflict
4242
- id: fix-encoding-pragma
4343
args: ["--remove"]

internnav/agent/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ def __init__(self, config: AgentCfg):
1010
self.config = config
1111

1212
def step(self, obs: Dict[str, Any]):
13-
pass
13+
raise NotImplementedError("This function is not implemented yet.")
1414

1515
def reset(self):
16-
pass
16+
raise NotImplementedError("This function is not implemented yet.")
1717

1818
@classmethod
1919
def register(cls, agent_type: str):

internnav/agent/rdp_agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from internnav.configs.agent import AgentCfg
1010
from internnav.configs.model.base_encoders import ModelCfg
1111
from internnav.model import get_config, get_policy
12-
from internnav.model.basemodel.LongCLIP.model import longclip
1312
from internnav.model.basemodel.rdp.utils import (
1413
FixedLengthStack,
1514
compute_actions,
@@ -19,7 +18,6 @@
1918
quat_to_euler_angles,
2019
to_local_coords_batch,
2120
)
22-
from internnav.model.utils.bert_token import BertTokenizer
2321
from internnav.model.utils.feature_extract import (
2422
extract_image_features,
2523
extract_instruction_tokens,
@@ -67,13 +65,17 @@ def __init__(self, config: AgentCfg):
6765

6866
if self.use_clip_encoders:
6967
if self._model_settings.text_encoder.type == 'roberta':
68+
from internnav.model.utils.bert_token import BertTokenizer
69+
7070
self.bert_tokenizer = BertTokenizer(
7171
max_length=self._model_settings.instruction_encoder.max_length,
7272
load_model=self._model_settings.instruction_encoder.load_model,
7373
device=self.device,
7474
)
7575
self.use_bert = True
7676
elif self._model_settings.text_encoder.type == 'clip-long':
77+
from internnav.model.basemodel.LongCLIP.model import longclip
78+
7779
self.bert_tokenizer = longclip.tokenize
7880
self.use_bert = True
7981
self.is_clip_long = True

internnav/agent/simple_agent.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import time
2+
from typing import Any, Dict
3+
4+
import torch
5+
6+
from internnav.agent import Agent
7+
from internnav.configs.agent import AgentCfg
8+
from internnav.model import get_config, get_policy
9+
10+
11+
class SimpleAgent(Agent):
12+
"""
13+
agent template, override the functions for custom policy
14+
"""
15+
16+
def __init__(self, agent_config: AgentCfg):
17+
self.agent_config = agent_config
18+
self.device = torch.device('cuda', 0)
19+
20+
# get policy by name
21+
policy = get_policy(agent_config.model_settings.policy_name)
22+
23+
# load policy checkpoints
24+
self.policy = policy.from_pretrained(
25+
agent_config.ckpt_path,
26+
config=get_config(agent_config.model_settings.policy_name)(
27+
model_cfg={'model': agent_config.model_settings.model_dump()}
28+
),
29+
).to(self.device)
30+
31+
def convert_input(self, obs):
32+
return obs
33+
34+
def convert_output(self, action):
35+
return action
36+
37+
def inference(self, input):
38+
return self.policy(input)
39+
40+
def step(self, obs: Dict[str, Any]):
41+
print(f'{self.config.model_name} Agent step')
42+
start = time.time()
43+
44+
# convert obs to model input
45+
obs = self.convert_input(obs)
46+
action = self.inference(obs)
47+
action = self.convert_output(action)
48+
49+
end = time.time()
50+
print(f'time: {round(end-start, 4)}s')
51+
return action
52+
53+
def reset(self):
54+
pass
File renamed without changes.

internnav/configs/trainer/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .vln_camera import VLNCameraCfg
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .vln_eval_task import VLNEvalTaskCfg

internnav/evaluator/utils/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from collections import defaultdict
77

88
import numpy as np
9-
from internutopia.core.util import is_in_container
109
from PIL import Image, ImageDraw
1110
from scipy.ndimage import binary_dilation
1211

@@ -243,6 +242,8 @@ def load_data(dataset_root_dir, split, filter_same_trajectory=True, filter_stair
243242

244243
def load_scene_usd(mp3d_data_dir, scan):
245244
"""Load scene USD based on the scan"""
245+
from internutopia.core.util import is_in_container
246+
246247
find_flag = False
247248
for root, dirs, files in os.walk(os.path.join(mp3d_data_dir, scan)):
248249
target_file_name = 'fixed_docker.usd' if is_in_container() else 'fixed.usd'

0 commit comments

Comments
 (0)