1313from internnav .evaluator .utils .data_collector import DataCollector
1414from internnav .evaluator .utils .dataset import ResultLogger , split_data
1515from internnav .evaluator .utils .eval import generate_episode
16+ from internnav .evaluator .utils .visualize_util import VisualizeUtil
1617from internnav .projects .dataloader .resumable import ResumablePathKeyDataloader
1718from internnav .utils import common_log_util , progress_log_multi_util
1819from internnav .utils .common_log_util import common_logger as log
@@ -44,7 +45,7 @@ def __init__(self, config: EvalCfg):
4445 # generate episode
4546 episodes = generate_episode (self .dataloader , config )
4647 if len (episodes ) == 0 :
47- log .info ("No more episodes to evaluate" )
48+ log .info ("No more episodes to evaluate. Episodes are saved in data/sample_episodes/ " )
4849 sys .exit (0 )
4950 config .task .task_settings .update ({'episodes' : episodes })
5051 self .env_num = config .task .task_settings ['env_num' ]
@@ -72,6 +73,9 @@ def __init__(self, config: EvalCfg):
7273 set_seed_model (0 )
7374 self .data_collector = DataCollector (self .dataloader .lmdb_path )
7475 self .robot_flash = config .task .robot_flash
76+ self .save_to_json = config .eval_settings ['save_to_json' ]
77+ self .vis_output = config .eval_settings ['vis_output' ]
78+ self .visualize_util = VisualizeUtil (self .task_name , fps = 6 )
7579
7680 @property
7781 def ignore_obs_attr (self ):
@@ -202,9 +206,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
202206 if terminated and self .runner_status [env_id ] != runner_status_code .TERMINATED :
203207 obs = obs_ls [env_id ]
204208 reset_info = reset_infos [env_id ]
205- if not __debug__ :
206- pass
207- log .info (json .dumps (obs ['metrics' ]))
209+ log .info (f"{ self .now_path_key (reset_info )} : { json .dumps (obs ['metrics' ], indent = 4 )} " )
208210 self .data_collector .save_eval_result (
209211 key = self .now_path_key (reset_info ),
210212 result = obs ['metrics' ][list (obs ['metrics' ].keys ())[0 ]][0 ]['fail_reason' ],
@@ -216,19 +218,29 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
216218 step_count = obs ['metrics' ][list (obs ['metrics' ].keys ())[0 ]][0 ]['steps' ],
217219 result = obs ['metrics' ][list (obs ['metrics' ].keys ())[0 ]][0 ]['fail_reason' ],
218220 )
221+ # visualize
222+ if self .vis_output :
223+ self .visualize_util .trace_end (
224+ trajectory_id = self .now_path_key (reset_info ),
225+ result = obs ['metrics' ][list (obs ['metrics' ].keys ())[0 ]][0 ]['fail_reason' ],
226+ )
227+ # json format result
228+ if self .save_to_json :
229+ self .result_logger .write_now_result_json ()
219230 self .result_logger .write_now_result ()
220231 self .runner_status [env_id ] = runner_status_code .NOT_RESET
221232 log .info (f'env{ env_id } : states switch to NOT_RESET.' )
222- reset_env_ids = np .where (self .runner_status == runner_status_code .NOT_RESET )[ # need this status to reset
223- 0
224- ].tolist ()
233+ # need this status to reset
234+ reset_env_ids = np .where (self .runner_status == runner_status_code .NOT_RESET )[0 ].tolist ()
225235 if len (reset_env_ids ) > 0 :
226236 log .info (f'env{ reset_env_ids } : start new episode!' )
227237 obs , new_reset_infos = self .env .reset (reset_env_ids )
228238 self .runner_status [reset_env_ids ] = runner_status_code .WARM_UP
229239 log .info (f'env{ reset_env_ids } : states switch to WARM UP.' )
240+
230241 # modify original reset_info
231242 reset_infos = np .array (reset_infos )
243+ # If there is only one reset and no new_deset_infos, return an empty array
232244 reset_infos [reset_env_ids ] = new_reset_infos if len (new_reset_infos ) > 0 else None
233245 self .runner_status [
234246 np .vectorize (lambda x : x )(reset_infos ) == None # noqa: E711
@@ -242,9 +254,15 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
242254 for reset_info in new_reset_infos :
243255 if reset_info is None :
244256 continue
257+ # start new trace log
245258 progress_log_multi_util .trace_start (
246259 trajectory_id = self .now_path_key (reset_info ),
247260 )
261+ # start new visualize log
262+ if self .vis_output :
263+ self .visualize_util .trace_start (
264+ trajectory_id = self .now_path_key (reset_info ), reference_path = reset_info .data ['reference_path' ]
265+ )
248266 return False , reset_infos
249267
250268 def eval (self ):
@@ -257,6 +275,10 @@ def eval(self):
257275 progress_log_multi_util .trace_start (
258276 trajectory_id = self .now_path_key (info ),
259277 )
278+ if self .vis_output :
279+ self .visualize_util .trace_start (
280+ trajectory_id = self .now_path_key (info ), reference_path = info .data ['reference_path' ]
281+ )
260282 log .info ('start new episode!' )
261283
262284 obs = self .warm_up ()
@@ -277,6 +299,16 @@ def eval(self):
277299 env_term , reset_info = self .terminate_ops (obs , reset_info , terminated )
278300 if env_term :
279301 break
302+
303+ # save step obs
304+ if self .vis_output :
305+ for ob , info , act in zip (obs , reset_info , action ):
306+ if info is None or 'rgb' not in ob or ob ['fail_reason' ]:
307+ continue
308+ self .visualize_util .save_observation (
309+ trajectory_id = self .now_path_key (info ), obs = ob , action = act [self .robot_name ]
310+ )
311+
280312 self .env .close ()
281313 progress_log_multi_util .report ()
282314
0 commit comments