@@ -165,11 +165,13 @@ async def websocket_reflection_backtracking(self, path, selected_node, websocket
165165 if websocket :
166166 await websocket .send_json ({
167167 "type" : "reflection_backtracking" ,
168- "path" : [node .action for node in path if node .action is not None ],
168+ "path" : [{
169+ "natural_language_description" : node .natural_language_description ,
170+ "action" : node .action } for node in path if node .action is not None ],
169171 "node_id" : id (selected_node ),
170- "node_parent_id " : id (selected_node .parent ),
171- "node_action " : selected_node .action ,
172- "node_description " : selected_node .natural_language_description ,
172+ "parent_id " : id (selected_node .parent ),
173+ "action " : selected_node .action ,
174+ "description " : selected_node .natural_language_description ,
173175 "trajectory" : selected_node .get_trajectory ()
174176 })
175177
@@ -304,80 +306,81 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
304306
305307 # Step 3: simulation using the current node, (generate a path using the current node, and score the path)
306308 # TODO: implement simulation using openai
307- print (f"{ GREEN } Step 3: Simulation{ RESET } " )
308- await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
309- path = self .get_path_to_root (selected_node )
310- # here score is the reward
311- score = await self .evaluate_selected_path (path )
312- # change to reward later?
313- if score > best_score :
314- best_score = score
315- best_path = path
316- best_node = selected_node
317- print (f"\n New best path found!" )
318- print (f"best score: { best_score :.3f} " )
319- print (f"best node: { best_node .action } " )
320- print (f"best node: { best_node .natural_language_description } " )
321- print (f"best path: { best_path } " )
322-
323- # add websocket information, just use websocket here
324- if websocket :
325- await self .websocket_simulation_result (score , selected_node , websocket = websocket )
309+ if selected_node != self .root_node :
310+ print (f"{ GREEN } Step 3: Simulation{ RESET } " )
311+ await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
312+ path = self .get_path_to_root (selected_node )
313+ # here score is the reward
314+ score = await self .evaluate_selected_path (path )
315+ # change to reward later?
316+ if score > best_score :
317+ best_score = score
318+ best_path = path
319+ best_node = selected_node
320+ print (f"\n New best path found!" )
321+ print (f"best score: { best_score :.3f} " )
322+ print (f"best node: { best_node .action } " )
323+ print (f"best node: { best_node .natural_language_description } " )
324+ print (f"best path: { best_path } " )
326325
326+ # add websocket information, just use websocket here
327+ if websocket :
328+ await self .websocket_simulation_result (score , selected_node , websocket = websocket )
327329
328- ## Step 4: reflection backtracking
329- print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
330- await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
331- if score >= self .config .reflection_score :
332- # Convert path to serializable trajectory
333- # trajectory = [node.action for node in path if node.action is not None]
334- await self .websocket_search_complete ("success" , score , selected_node .get_trajectory (), websocket = websocket )
335- await self .playwright_manager .close ()
336- return selected_node
337330
338- print (f"path: { path } " )
339- path , current_node = await self .reflection_backtracking (path )
340- print (f"path: { path } " )
341- print (f"current_node: { current_node .action } " )
342- print (f"current_node: { current_node .natural_language_description } " )
331+ ## Step 4: reflection backtracking
332+ print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
333+ await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
334+ if score >= self .config .reflection_score :
335+ # Convert path to serializable trajectory
336+ # trajectory = [node.action for node in path if node.action is not None]
337+ await self .websocket_search_complete ("success" , score , selected_node .get_trajectory (), websocket = websocket )
338+ await self .playwright_manager .close ()
339+ return selected_node
343340
344- # add websocket information, just use websocket here
345- if websocket :
346- await self .websocket_reflection_backtracking (path , current_node , websocket = websocket )
341+ print (f"path: { path } " )
342+ path , current_node = await self .reflection_backtracking (path )
343+ print (f"path: { path } " )
344+ print (f"current_node: { current_node .action } " )
345+ print (f"current_node: { current_node .natural_language_description } " )
347346
348- # Step 5: backpropagation
349- print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
350- await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
351- for node in path :
352- if node != self .root_node :
353- old_value = node .value
354- node .visits += 1
355- node .value += (score - node .value ) / node .visits
356- # consiste with lats backpropagation
357- #node.value = (node.value * (node.visits - 1) + score) / node.visits
358- print (f"Node { node .action } :" )
359- print (f" Visits: { node .visits } " )
360- print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
361347 # add websocket information, just use websocket here
362- # if websocket:
363- # await websocket.send_json({
364- # "type": "backpropagation",
365- # "node_id": id(node),
366- # "node_parent_id": id(node.parent),
367- # "node_action": node.action,
368- # "node_value": node.value,
369- # "node_visits": node.visits,
370- # "node_old_value": old_value,
371- # "node_description": node.natural_language_description,
372- # })
348+ if websocket :
349+ await self .websocket_reflection_backtracking (path , current_node , websocket = websocket )
373350
374- tree_data = self ._get_tree_data ()
375- print_entire_tree (self .root_node )
376- print (tree_data )
377- if websocket :
378- await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
379- else :
351+ # Step 5: backpropagation
352+ print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
353+ await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
354+ for node in path :
355+ if node != self .root_node :
356+ old_value = node .value
357+ node .visits += 1
358+ node .value += (score - node .value ) / node .visits
359+ # consiste with lats backpropagation
360+ #node.value = (node.value * (node.visits - 1) + score) / node.visits
361+ print (f"Node { node .action } :" )
362+ print (f" Visits: { node .visits } " )
363+ print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
364+ # add websocket information, just use websocket here
365+ # if websocket:
366+ # await websocket.send_json({
367+ # "type": "backpropagation",
368+ # "node_id": id(node),
369+ # "node_parent_id": id(node.parent),
370+ # "node_action": node.action,
371+ # "node_value": node.value,
372+ # "node_visits": node.visits,
373+ # "node_old_value": old_value,
374+ # "node_description": node.natural_language_description,
375+ # })
376+
377+ tree_data = self ._get_tree_data ()
380378 print_entire_tree (self .root_node )
379+ print (tree_data )
380+ if websocket :
381+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
382+ else :
383+ print_entire_tree (self .root_node )
381384 if best_node :
382385 # Convert node to serializable trajectory
383386 # trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
0 commit comments