11from typing import Any , Optional , Tuple , List
22from datetime import datetime
3+ import logging
4+ import json
5+ import time
6+ from openai import OpenAI
37from dotenv import load_dotenv
48load_dotenv ()
59
6- from .tree_vis import RED , better_print , print_trajectory , collect_all_nodes , GREEN , RESET , print_entire_tree
10+ from .tree_vis import RED , GREEN , RESET , better_print , print_trajectory , collect_all_nodes , print_entire_tree
711from .lats_node import LATSNode
812from .base_agent import BaseAgent
13+ from .trajectory_score import create_llm_prompt , score_trajectory_with_openai
14+ from ...replay_async import generate_feedback , playwright_step_execution
15+ from ...webagent_utils_async .browser_env .observation import extract_page_info
16+ from ...webagent_utils_async .action .prompt_functions import extract_top_actions
17+ from ...webagent_utils_async .utils .utils import parse_function_args , locate_element
18+ from ...evaluation_async .evaluators import goal_finished_evaluator
19+
20+ openai_client = OpenAI ()
21+
22+ logger = logging .getLogger (__name__ )
23+ logger .setLevel (logging .INFO )
924
1025class MCTSAgent (BaseAgent ):
11- async def run (self , websocket = None ) -> list [LATSNode ]:
26+ """
27+ Monte Carlo Tree Search Agent for web navigation tasks.
28+ This implementation uses reflection-based search to improve performance.
29+ """
30+
31+ async def run (self , websocket = None ) -> List [dict [str , Any ]]:
32+ """
33+ Run the MCTS algorithm based on configuration.
34+
35+ Args:
36+ websocket: Optional WebSocket connection to send updates to
37+
38+ Returns:
39+ List[Dict[str, Any]]: List of actions in the best path found
40+ """
1241 if websocket :
1342 await websocket .send_json ({
1443 "type" : "search_status" ,
@@ -17,4 +46,270 @@ async def run(self, websocket=None) -> list[LATSNode]:
1746 "timestamp" : datetime .utcnow ().isoformat ()
1847 })
1948
20- pass
49+ # Reset browser to initial state
50+ live_browser_url , session_id = await self ._reset_browser (websocket )
51+
52+ best_node = await self .mcts_search (websocket )
53+ print_trajectory (best_node )
54+
55+ return best_node
56+
async def node_selection(self, node: LATSNode, websocket=None) -> Optional[LATSNode]:
    """
    Walk down the tree from ``node``, letting the evaluation model pick a child at each level.

    Args:
        node: Node the descent starts from.
        websocket: Optional WebSocket connection forwarded to the
            node-selection update sent once the descent stops.

    Returns:
        Optional[LATSNode]: ``None`` when ``node`` is already terminal.
        Otherwise the node at which the descent stopped: a node without
        children, a terminal node, or the last node reached before the
        model returned an unusable answer.
    """
    if node.is_terminal:
        return None

    current_node = node
    path = [current_node]
    selection_depth = 0

    while current_node.children and not current_node.is_terminal:
        children = current_node.children
        logger.info(f"\n Selection Step {selection_depth + 1} :")
        logger.info(f"Current node action: {current_node.action} ")
        logger.info(f"Number of children: {len(children)} ")

        # Trajectory walked so far (root excluded), shown to the model as context.
        trajectory = [
            {
                "natural_language_description": n.natural_language_description,
                "action": n.action,
                "feedback": getattr(n, 'feedback', None),
            }
            for n in path[1:]
        ]

        # Summary of the candidate children the model has to choose between.
        children_json = json.dumps(
            [
                {
                    'action': child.action,
                    'description': child.natural_language_description,
                    'visits': child.visits,
                    'value': getattr(child, 'value', 0),
                }
                for child in children
            ],
            indent=2,
        )

        prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next.
Consider the overall progress, efficiency, and likelihood of success.

Goal: {self.goal}

Current Trajectory:
{json.dumps(trajectory, indent=2)}

Available Children:
{children_json}

Return a JSON response with:
{{
"selected_child_index": int, # Index of the selected child
"explanation": str # Brief explanation of the selection
}}"""

        response = openai_client.chat.completions.create(
            model=self.config.evaluation_model,
            messages=[
                {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
        )

        # The reply is model-generated: it may be empty, not valid JSON, not
        # an object, or lack the expected keys. Treat all of those like an
        # out-of-range index instead of raising in the middle of the search.
        try:
            selection = json.loads(response.choices[0].message.content)
        except (TypeError, json.JSONDecodeError):
            selection = {}
        if not isinstance(selection, dict):
            selection = {}
        selected_index = selection.get("selected_child_index")

        index_is_usable = (
            isinstance(selected_index, int)
            and not isinstance(selected_index, bool)
            and 0 <= selected_index < len(children)
        )
        if not index_is_usable:
            logger.warning(f"Invalid child index {selected_index} , breaking selection")
            break

        current_node = children[selected_index]
        path.append(current_node)
        logger.info(f"Selected child {selected_index + 1} : {current_node.action} ")
        logger.info(f"Selection explanation: {selection.get('explanation')} ")

        selection_depth += 1

    # Send final node selection update
    await self.websocket_node_selection(current_node, websocket=websocket)
    return current_node
129+
async def evaluate_selected_path(self, path) -> float:
    """
    Score the trajectory described by ``path`` with the evaluation model.

    Args:
        path: Nodes ordered from the root; the first entry (the root) is
            skipped when the trajectory is built.

    Returns:
        float: The ``overall_score`` entry of the scoring result
        (presumably numeric, since it is printed with a ``.3f`` format).
    """
    # Build the trajectory for scoring (root excluded). ``feedback`` is read
    # defensively, matching how the other methods of this class build the
    # same trajectory.
    trajectory = [
        {
            "natural_language_description": n.natural_language_description,
            "action": n.action,
            "feedback": getattr(n, 'feedback', None),
        }
        for n in path[1:]
    ]

    # TODO: if node is terminal, score is 0?
    prompt = create_llm_prompt(trajectory, self.goal)
    print(f"prompt: {prompt} ")
    result = score_trajectory_with_openai(
        prompt,
        openai_client,
        model=self.config.evaluation_model,
    )
    print(f"result: {result} ")

    score = result["overall_score"]
    print("Simulation Results, evaluate selected path:")
    print(f"Overall Score: {score:.3f} ")
    print(f"Efficiency Score: {result['efficiency_score']:.3f} ")
    print(f"Accuracy Score: {result['accuracy_score']:.3f} ")
    print(f"Robustness Score: {result['robustness_score']:.3f} ")
    return score
163+
async def reflection_backtracking(self, path) -> List[LATSNode]:
    """
    Ask the evaluation model where to backtrack to and truncate ``path`` there.

    Args:
        path: Current path from the root to the selected node. The list is
            shortened in place when the model proposes a usable step.

    Returns:
        List[LATSNode]: The same ``path`` list, truncated after the
        backtrack step, or unchanged when the model's answer is unusable.
    """
    # Trajectory shown to the model (root excluded).
    trajectory = [
        {
            "natural_language_description": n.natural_language_description,
            "action": n.action,
            "feedback": getattr(n, 'feedback', None),
        }
        for n in path[1:]
    ]

    score = await self.evaluate_selected_path(path)
    print(f"\n Reflection Step (Score {score:.3f} < {self.config.reflection_score} ):")

    reflection_prompt = f"""Analyze the current trajectory and suggest improvements for the current website.

Goal: {self.goal}

Current Trajectory:
{json.dumps(trajectory, indent=2)}

Score: {score}

Return a JSON response with:
{{
"backtrack_to_step": int, # Which step to backtrack to (0-based index)
"reason": str, # Why backtrack to this step
"suggested_improvements": [str] # List of suggested improvements specific to current websites
}}"""

    reflection = openai_client.chat.completions.create(
        model=self.config.evaluation_model,
        messages=[
            {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."},
            {"role": "user", "content": reflection_prompt},
        ],
        response_format={"type": "json_object"},
    )

    # The reply is model-generated; tolerate an empty body, invalid JSON,
    # a non-object payload and missing keys rather than raising.
    try:
        reflection_result = json.loads(reflection.choices[0].message.content)
    except (TypeError, json.JSONDecodeError):
        reflection_result = {}
    if not isinstance(reflection_result, dict):
        reflection_result = {}
    backtrack_step = reflection_result.get("backtrack_to_step")

    step_is_usable = (
        isinstance(backtrack_step, int)
        and not isinstance(backtrack_step, bool)
        and 0 <= backtrack_step < len(path)
    )
    if not step_is_usable:
        logger.warning(f"Invalid backtrack step {backtrack_step}, keeping current path")
        return path

    # NOTE(review): the step is used as an index into ``path`` (which
    # includes the root), while the trajectory shown to the model omits the
    # root - confirm the intended indexing.
    # Prevent backtracking to root when we have actions
    if backtrack_step == 0 and len(path) > 1:
        backtrack_step = 1
        print("Adjusted backtracking to maintain at least one action")

    # Drop every node after the backtrack point, in place.
    del path[backtrack_step + 1:]

    print(f"Backtracking to step {backtrack_step} ")
    print(f"Reason: {reflection_result.get('reason')} ")
    print("Suggested improvements:")
    for improvement in reflection_result.get("suggested_improvements") or []:
        print(f"- {improvement} ")

    return path
235+
async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
    """
    Run the search loop for ``self.config.iterations`` iterations.

    Each iteration selects a node, expands it, scores the path from the root
    to it, and - unless the score reaches ``self.config.reflection_score`` -
    runs reflection backtracking and backpropagates the score along the
    resulting path.

    Args:
        websocket: Optional WebSocket connection to send progress updates to.

    Returns:
        Optional[LATSNode]: The node whose path first reaches the reflection
        score; otherwise the best-scoring node seen, or ``None`` when no
        node was scored at all.
    """
    best_score = float('-inf')
    best_node = None
    print(f"iterations: {self.config.iterations} ")

    for i in range(self.config.iterations):
        await self.websocket_iteration_start(i, websocket=websocket)

        print(f"\n {'=' * 50} ")
        print(f"MCTS Iteration {i + 1} /{self.config.iterations} ")
        print(f"{'=' * 50} \n ")

        # Step 1: Node Selection (contain simulation)
        # "node selection" combines selection and partial simulation
        print(f"{GREEN} Step 1: Node Selection{RESET} ")
        await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
        node = await self.node_selection(self.root_node, websocket)

        if node is None:
            logger.warning("All paths lead to terminal nodes. Ending search.")
            break

        # Step 2: Node Expansion
        print(f"{GREEN} Step 2: Node Expansion{RESET} ")
        await self.websocket_step_start(step=2, step_name="node_expansion", websocket=websocket)
        await self.node_expansion(node, websocket)
        tree_data = self._get_tree_data()
        if websocket:
            await self.websocket_tree_update(type="tree_update_node_expansion", websocket=websocket, tree_data=tree_data)
        else:
            print_entire_tree(self.root_node)

        # Step 3: simulation using the current node, (generate a path using the current node, and score the path)
        # TODO: implement simulation using openai
        print(f"{GREEN} Step 3: Simulation{RESET} ")
        await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
        path = self.get_path_to_root(node)
        score = await self.evaluate_selected_path(path)
        # change to reward later?
        if score > best_score:
            # Report before overwriting so "previous" really is the previous best.
            print(f"\n New best path found!")
            print(f"Previous best score: {best_score:.3f} ")
            print(f"New best score: {score:.3f} ")
            best_score = score
            # Remember the node itself: it is what gets reported and returned
            # when no path reaches the reflection score.
            best_node = node

        ## Step 4: reflection backtracking
        print(f"{GREEN} Step 4: Reflection Backtracking{RESET} ")
        await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
        if score >= self.config.reflection_score:
            # Convert path to serializable trajectory
            trajectory = [step.action for step in path if step.action is not None]
            await self.websocket_search_complete("success", score, trajectory, websocket=websocket)
            return node

        print(f"path: {path} ")
        path = await self.reflection_backtracking(path)
        print(f"path: {path} ")

        # Step 5: backpropagation
        print(f"{GREEN} Step 5: Backpropagation{RESET} ")
        await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
        for visited in path:
            old_value = visited.value
            visited.visits += 1
            # Incremental running mean of the scores seen through this node.
            visited.value = (visited.value * (visited.visits - 1) + score) / visited.visits
            print(f"Node {visited.action} :")
            print(f" Visits: {visited.visits} ")
            print(f" Value: {old_value:.3f} -> {visited.value:.3f} ")

    if best_node is not None:
        await self.websocket_search_complete("partial_success", best_node.value, best_node.get_trajectory(), websocket=websocket)
    return best_node
0 commit comments