
Commit e0d265c

[Fix] Restore visualization feature and refactor RDP utilities into geometry utils (#145)
* add back visualize feature; refactor rdp utils to geometry utils
* remove json file
1 parent c48cdc2 commit e0d265c

7 files changed: +76, -66 lines


internnav/agent/rdp_agent.py

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,12 @@
 from internnav.configs.agent import AgentCfg
 from internnav.configs.model.base_encoders import ModelCfg
 from internnav.model import get_config, get_policy
-from internnav.model.basemodel.rdp.utils import (
+from internnav.model.utils.feature_extract import (
+    extract_image_features,
+    extract_instruction_tokens,
+)
+from internnav.utils.common_log_util import common_logger as log
+from internnav.utils.geometry_utils import (
     FixedLengthStack,
     compute_actions,
     get_delta,
@@ -18,11 +23,6 @@
     quat_to_euler_angles,
     to_local_coords_batch,
 )
-from internnav.model.utils.feature_extract import (
-    extract_image_features,
-    extract_instruction_tokens,
-)
-from internnav.utils.common_log_util import common_logger as log


 @Agent.register('rdp')
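
In short, the RDP geometry helpers move out of the model package into `internnav.utils.geometry_utils`; the function names are unchanged in this diff. A minimal sketch of what downstream code would need to change, assuming only the module path differs:

    # Before this commit (module path removed in #145):
    # from internnav.model.basemodel.rdp.utils import get_delta, normalize_data, to_local_coords

    # After this commit:
    from internnav.utils.geometry_utils import get_delta, normalize_data, to_local_coords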

internnav/dataset/rdp_lerobot_dataset.py

Lines changed: 10 additions & 20 deletions
@@ -1,11 +1,9 @@
+import copy
 import random
 from collections import defaultdict

-import lmdb
-import msgpack_numpy
 import numpy as np
 import torch
-import copy
 from PIL import Image
 from torchvision.transforms import (
     CenterCrop,
@@ -26,8 +24,8 @@
 from internnav.dataset.base import BaseDataset, ObservationsDict, _block_shuffle
 from internnav.evaluator.utils.common import norm_depth
 from internnav.model.basemodel.LongCLIP.model import longclip
-from internnav.model.basemodel.rdp.utils import get_delta, normalize_data, to_local_coords
 from internnav.model.utils.feature_extract import extract_instruction_tokens
+from internnav.utils.geometry_utils import get_delta, normalize_data, to_local_coords
 from internnav.utils.lerobot_as_lmdb import LerobotAsLmdb


@@ -174,14 +172,12 @@ def _load_next(self): # noqa: C901
             data['camera_info'][self.camera_name]['rgb'] = data['camera_info'][self.camera_name]['rgb'][
                 :-drop_last_frame_nums
             ]
-            data['camera_info'][self.camera_name]['depth'] = data['camera_info'][self.camera_name][
-                'depth'
-            ][:-drop_last_frame_nums]
-            data['robot_info']['yaw'] = data['robot_info']['yaw'][:-drop_last_frame_nums]
-            data['robot_info']['position'] = data['robot_info']['position'][:-drop_last_frame_nums]
-            data['robot_info']['orientation'] = data['robot_info']['orientation'][
+            data['camera_info'][self.camera_name]['depth'] = data['camera_info'][self.camera_name]['depth'][
                 :-drop_last_frame_nums
             ]
+            data['robot_info']['yaw'] = data['robot_info']['yaw'][:-drop_last_frame_nums]
+            data['robot_info']['position'] = data['robot_info']['position'][:-drop_last_frame_nums]
+            data['robot_info']['orientation'] = data['robot_info']['orientation'][:-drop_last_frame_nums]
             data['progress'] = data['progress'][:-drop_last_frame_nums]
             data['step'] = data['step'][:-drop_last_frame_nums]

@@ -192,7 +188,7 @@ def _load_next(self): # noqa: C901
                 if yaw > np.pi:
                     yaw -= 2 * np.pi
                 yaws[yaw_i] = yaw
-
+
         episodes_in_json = data_to_load['episodes_in_json']

         instructions = [
@@ -221,7 +217,6 @@ def _load_next(self): # noqa: C901
                 new_preload, self.bert_tokenizer, is_clip_long=self.is_clip_long
             )

-
         # process the instruction
         # copy the instruction to each step
         if self.need_extract_instr_features:
@@ -447,12 +442,7 @@ def _pad_helper(t, max_len, fill_val=0, return_masks=False):
         observations_batch = ObservationsDict(observations_batch)
         # Expand B to match the flattened batch size
         B_expanded = B.repeat(observations_batch['prev_actions'].shape[0]).view(-1, 1)
-
-        return (
-            observations_batch,
-            observations_batch['prev_actions'],
-            not_done_masks_batch.view(-1, 1),
-            B_expanded
-        )
-
+
+        return (observations_batch, observations_batch['prev_actions'], not_done_masks_batch.view(-1, 1), B_expanded)
+
     return _rdp_collate_fn

internnav/dataset/rdp_lmdb_dataset.py

Lines changed: 5 additions & 10 deletions
@@ -1,11 +1,11 @@
+import copy
 import random
 from collections import defaultdict

 import lmdb
 import msgpack_numpy
 import numpy as np
 import torch
-import copy
 from PIL import Image
 from torchvision.transforms import (
     CenterCrop,
@@ -26,8 +26,8 @@
 from internnav.dataset.base import BaseDataset, ObservationsDict, _block_shuffle
 from internnav.evaluator.utils.common import norm_depth
 from internnav.model.basemodel.LongCLIP.model import longclip
-from internnav.model.basemodel.rdp.utils import get_delta, normalize_data, to_local_coords
 from internnav.model.utils.feature_extract import extract_instruction_tokens
+from internnav.utils.geometry_utils import get_delta, normalize_data, to_local_coords


 def _convert_image_to_rgb(image):
@@ -466,12 +466,7 @@ def _pad_helper(t, max_len, fill_val=0, return_masks=False):
         observations_batch = ObservationsDict(observations_batch)
         # Expand B to match the flattened batch size
         B_expanded = B.repeat(observations_batch['prev_actions'].shape[0]).view(-1, 1)
-
-        return (
-            observations_batch,
-            observations_batch['prev_actions'],
-            not_done_masks_batch.view(-1, 1),
-            B_expanded
-        )
-
+
+        return (observations_batch, observations_batch['prev_actions'], not_done_masks_batch.view(-1, 1), B_expanded)
+
     return _rdp_collate_fn

internnav/evaluator/utils/common.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,7 @@
 from scipy.ndimage import binary_dilation

 from internnav.utils.common_log_util import common_logger as log
+from internnav.utils.geometry_utils import quat_to_euler_angles


 def create_robot_mask(topdown_global_map_camera, mask_size=20):
@@ -343,7 +344,6 @@ def draw_trajectory(array, obs_lst, reference_path):
     import matplotlib.pyplot as plt
     import numpy as np
     from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
-    from omni.isaac.core.utils.rotations import quat_to_euler_angles

     from internnav.evaluator.utils.path_plan import world_to_pixel

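
Worth noting: `draw_trajectory` previously imported `quat_to_euler_angles` from Isaac Sim inside the function body; it now uses the project's own copy in `internnav.utils.geometry_utils`, imported once at module level. A hedged sketch of the call, assuming the replacement mirrors the Isaac Sim helper it displaces (quaternion in (w, x, y, z) order, Euler angles in radians), which this diff does not confirm:

    import numpy as np

    from internnav.utils.geometry_utils import quat_to_euler_angles

    # Identity quaternion -> zero rotation. The (w, x, y, z) ordering and radian
    # output are assumptions carried over from the Isaac Sim helper it replaces.
    euler_xyz = quat_to_euler_angles(np.array([1.0, 0.0, 0.0, 0.0]))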

internnav/utils/visualize_util.py renamed to internnav/evaluator/utils/visualize_util.py

Lines changed: 15 additions & 21 deletions
@@ -1,23 +1,18 @@
+import logging
 import os
 import time
-import logging
 from dataclasses import dataclass
-from typing import Callable, Dict, Optional, Union, Any, List
-
-import numpy as np
+from typing import Dict, List, Optional

 from internnav import PROJECT_ROOT_PATH
-from .common_log_util import get_task_name
-from internnav.evaluator.utils.common import obs_to_image, images_to_video
+from internnav.evaluator.utils.common import images_to_video, obs_to_image

 try:
-    from PIL import Image
     _PIL_AVAILABLE = True
 except Exception:
     _PIL_AVAILABLE = False

 try:
-    import imageio.v2 as imageio
     _IMAGEIO_AVAILABLE = True
 except Exception:
     _IMAGEIO_AVAILABLE = False
@@ -49,15 +44,16 @@ class VisualizeUtil:
         save_frame_fn(image, out_path) and save_video_fn(frames_dir, out_path, fps)
     Otherwise, built-ins (PIL + imageio) are used.
     """
+
     def __init__(
         self,
         dataset_name: str,
         fps: int = 10,
         img_ext: str = "png",
         video_ext: str = "mp4",
         root_subdir: str = "video",
-        save_frame_fn = obs_to_image,
-        save_video_fn = images_to_video,
+        save_frame_fn=obs_to_image,
+        save_video_fn=images_to_video,
     ):
         self.dataset_name = dataset_name
         self.fps = fps
@@ -73,14 +69,16 @@ def __init__(
         file_handler.setLevel(logging.INFO)
         file_handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s"))
         # Avoid adding duplicate handlers in repeated inits
-        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == file_handler.baseFilename
-                   for h in viz_logger.handlers):
+        if not any(
+            isinstance(h, logging.FileHandler) and h.baseFilename == file_handler.baseFilename
+            for h in viz_logger.handlers
+        ):
             viz_logger.addHandler(file_handler)

         self.base_dir = base_dir

         # Pluggable savers
-        self._save_frame_fn = save_frame_fn 
+        self._save_frame_fn = save_frame_fn
         self._save_video_fn = save_video_fn

         # Metrics
@@ -103,7 +101,7 @@ def trace_start(self, trajectory_id: str, reference_path):
             fps=self.fps,
             start_time=time.time(),
             saved_frames=[],
-            reference_path=reference_path
+            reference_path=reference_path,
         )
         viz_logger.info(f"[start] trajectory_id={trajectory_id}")

@@ -123,7 +121,7 @@ def save_observation(

         if step_index is None:
             step_index = ti.frame_count
-
+
         ti.frame_count += 1
         if ti.saved_frames is not None:
             ti.saved_frames.append(obs)
@@ -133,7 +131,6 @@
         out_path = os.path.join(ti.frames_dir, fname)
         self._save_frame_fn(ti.saved_frames, action, out_path, ti.reference_path)

-
     def trace_end(self, trajectory_id: str, result: Optional[str] = None, assemble_video: bool = True):
         """
         Mark trajectory finished and (optionally) assemble video.
@@ -153,7 +150,7 @@ def trace_end(self, trajectory_id: str, result: Optional[str] = None, assemble_v
         if assemble_video:
             self._save_video_fn(ti.frames_dir, ti.video_path, ti.fps)
             viz_logger.info(f"[video] saved {ti.video_path}")
-
+
         self._del_traj(trajectory_id)

     def report(self):
@@ -179,15 +176,12 @@ def report(self):
             f"[duration:{duration}s] [frames:{total_frames}] [avg_fps:{fps}] results:{result_map}"
         )

-
     def _require_traj(self, trajectory_id: str) -> TrajectoryVizInfo:
         if trajectory_id not in self.trajectories:
             raise KeyError(f"trajectory_id not started: {trajectory_id}")
         return self.trajectories[trajectory_id]
-
+
     def _del_traj(self, trajectory_id: str) -> None:
         if trajectory_id not in self.trajectories:
             raise KeyError(f"trajectory_id not started: {trajectory_id}")
         del self.trajectories[trajectory_id]
-
-
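
For context, the evaluator below drives this class through a start / per-step / end lifecycle. A minimal usage sketch, assuming the constructor and method signatures shown in this diff; `episode_rollout`, the observation payloads, the reference path, and the result string are placeholders:

    from internnav.evaluator.utils.visualize_util import VisualizeUtil

    viz = VisualizeUtil("my_dataset", fps=6)  # the evaluator constructs it with fps=6
    viz.trace_start(trajectory_id="path_key_0", reference_path=[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

    episode_rollout = []  # placeholder; the evaluator supplies real (obs, action) pairs each step
    for obs, action in episode_rollout:
        viz.save_observation(trajectory_id="path_key_0", obs=obs, action=action)

    viz.trace_end(trajectory_id="path_key_0", result="success", assemble_video=True)
    viz.report()  # logs duration, frame count, average fps, and per-result counts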

internnav/evaluator/vln_multi_evaluator.py

Lines changed: 39 additions & 7 deletions
@@ -13,6 +13,7 @@
 from internnav.evaluator.utils.data_collector import DataCollector
 from internnav.evaluator.utils.dataset import ResultLogger, split_data
 from internnav.evaluator.utils.eval import generate_episode
+from internnav.evaluator.utils.visualize_util import VisualizeUtil
 from internnav.projects.dataloader.resumable import ResumablePathKeyDataloader
 from internnav.utils import common_log_util, progress_log_multi_util
 from internnav.utils.common_log_util import common_logger as log
@@ -44,7 +45,7 @@ def __init__(self, config: EvalCfg):
         # generate episode
         episodes = generate_episode(self.dataloader, config)
         if len(episodes) == 0:
-            log.info("No more episodes to evaluate")
+            log.info("No more episodes to evaluate. Episodes are saved in data/sample_episodes/")
             sys.exit(0)
         config.task.task_settings.update({'episodes': episodes})
         self.env_num = config.task.task_settings['env_num']
@@ -72,6 +73,9 @@ def __init__(self, config: EvalCfg):
         set_seed_model(0)
         self.data_collector = DataCollector(self.dataloader.lmdb_path)
         self.robot_flash = config.task.robot_flash
+        self.save_to_json = config.eval_settings['save_to_json']
+        self.vis_output = config.eval_settings['vis_output']
+        self.visualize_util = VisualizeUtil(self.task_name, fps=6)

     @property
     def ignore_obs_attr(self):
@@ -202,9 +206,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
             if terminated and self.runner_status[env_id] != runner_status_code.TERMINATED:
                 obs = obs_ls[env_id]
                 reset_info = reset_infos[env_id]
-                if not __debug__:
-                    pass
-                log.info(json.dumps(obs['metrics']))
+                log.info(f"{self.now_path_key(reset_info)}: {json.dumps(obs['metrics'], indent=4)}")
                 self.data_collector.save_eval_result(
                     key=self.now_path_key(reset_info),
                     result=obs['metrics'][list(obs['metrics'].keys())[0]][0]['fail_reason'],
@@ -216,19 +218,29 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
                     step_count=obs['metrics'][list(obs['metrics'].keys())[0]][0]['steps'],
                     result=obs['metrics'][list(obs['metrics'].keys())[0]][0]['fail_reason'],
                 )
+                # visualize
+                if self.vis_output:
+                    self.visualize_util.trace_end(
+                        trajectory_id=self.now_path_key(reset_info),
+                        result=obs['metrics'][list(obs['metrics'].keys())[0]][0]['fail_reason'],
+                    )
+                # json format result
+                if self.save_to_json:
+                    self.result_logger.write_now_result_json()
                 self.result_logger.write_now_result()
                 self.runner_status[env_id] = runner_status_code.NOT_RESET
                 log.info(f'env{env_id}: states switch to NOT_RESET.')
-        reset_env_ids = np.where(self.runner_status == runner_status_code.NOT_RESET)[  # need this status to reset
-            0
-        ].tolist()
+        # need this status to reset
+        reset_env_ids = np.where(self.runner_status == runner_status_code.NOT_RESET)[0].tolist()
         if len(reset_env_ids) > 0:
             log.info(f'env{reset_env_ids}: start new episode!')
             obs, new_reset_infos = self.env.reset(reset_env_ids)
             self.runner_status[reset_env_ids] = runner_status_code.WARM_UP
             log.info(f'env{reset_env_ids}: states switch to WARM UP.')
+
         # modify original reset_info
         reset_infos = np.array(reset_infos)
+        # If there is only one reset and no new_deset_infos, return an empty array
         reset_infos[reset_env_ids] = new_reset_infos if len(new_reset_infos) > 0 else None
         self.runner_status[
             np.vectorize(lambda x: x)(reset_infos) == None  # noqa: E711
@@ -242,9 +254,15 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
         for reset_info in new_reset_infos:
             if reset_info is None:
                 continue
+            # start new trace log
             progress_log_multi_util.trace_start(
                 trajectory_id=self.now_path_key(reset_info),
             )
+            # start new visualize log
+            if self.vis_output:
+                self.visualize_util.trace_start(
+                    trajectory_id=self.now_path_key(reset_info), reference_path=reset_info.data['reference_path']
+                )
         return False, reset_infos

     def eval(self):
@@ -257,6 +275,10 @@ def eval(self):
             progress_log_multi_util.trace_start(
                 trajectory_id=self.now_path_key(info),
             )
+            if self.vis_output:
+                self.visualize_util.trace_start(
+                    trajectory_id=self.now_path_key(info), reference_path=info.data['reference_path']
+                )
         log.info('start new episode!')

         obs = self.warm_up()
@@ -277,6 +299,16 @@
             env_term, reset_info = self.terminate_ops(obs, reset_info, terminated)
             if env_term:
                 break
+
+            # save step obs
+            if self.vis_output:
+                for ob, info, act in zip(obs, reset_info, action):
+                    if info is None or 'rgb' not in ob or ob['fail_reason']:
+                        continue
+                    self.visualize_util.save_observation(
+                        trajectory_id=self.now_path_key(info), obs=ob, action=act[self.robot_name]
+                    )
+
         self.env.close()
         progress_log_multi_util.report()

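The two new switches are read straight from `config.eval_settings` in `__init__`. A hedged sketch of how an evaluation config might enable them; the diff only shows dict-style access on `eval_settings`, so a plain dict stands in for the real `EvalCfg` field and the values are illustrative only:

    # Assumed shape: mirrors the keys accessed in __init__ above.
    eval_settings = {
        'vis_output': True,    # drive VisualizeUtil: trace_start / save_observation / trace_end per episode
        'save_to_json': True,  # additionally write per-episode results via result_logger.write_now_result_json()
    }
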
Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import base64
 import math
-import os
 import pickle

 import numpy as np
