"""Language Agent Tree Search (LATS) agent implementation."""
22
3- import logging
43import time
54from typing import Any , Optional , Tuple , List
65
3433from ...webagent_utils_async .browser_env .observation import extract_page_info
3534from ...webagent_utils_async .evaluation .feedback import capture_post_action_feedback
3635
37- logger = logging .getLogger (__name__ )
3836openai_client = OpenAI ()
3937
4038class LATSAgent :
@@ -101,95 +99,85 @@ async def run(self) -> list[LATSNode]:
10199 Returns:
102100 list[LATSNode]: Best path from root to terminal node
103101 """
104- pass
105- # best_node = self.lats_search()
106- # print_trajectory(best_node)
107- # return best_node.get_trajectory()
102+ best_node = await self .lats_search ()
103+ print_trajectory (best_node )
104+ return best_node .get_trajectory ()
108105
109- def lats_search (self ) -> LATSNode :
106+ async def lats_search (self ) -> LATSNode :
110107 """
111108 Perform the main LATS search algorithm.
112109
113110 Returns:
114111 LATSNode: Best terminal node found
115112 """
116- logger . info (f"" )
117- logger . info (f"{ GREEN } START SEARCH{ RESET } " )
113+ print (f"" )
114+ print (f"{ GREEN } START SEARCH{ RESET } " )
118115
119116 terminal_nodes = []
120117
121118 for i in range (self .config .iterations ):
122- logger . info (f"" )
123- logger . info (f"" )
124- logger . info (f"Iteration { i + 1 } ..." )
119+ print (f"" )
120+ print (f"" )
121+ print (f"Iteration { i + 1 } ..." )
125122
126123 # Step 1: Selection
127- logger . info (f"" )
128- logger . info (f"{ GREEN } Step 1: selection{ RESET } " )
124+ print (f"" )
125+ print (f"{ GREEN } Step 1: selection{ RESET } " )
129126 node = self .select_node (self .root_node )
130127
131128 if node is None :
132- logger . info ("All paths lead to terminal nodes with reward 0. Ending search." )
129+ print ("All paths lead to terminal nodes with reward 0. Ending search." )
133130 break
134131
135132 print (f"{ GREEN } Tree:{ RESET } " )
136133 better_print (node = self .root_node , selected_node = node )
137134 print (f"" )
138135
139136 # Step 2: Expansion
140- logger . info (f"" )
141- logger . info (f"{ GREEN } Step 2: expansion{ RESET } " )
142- self .expand_node (node )
137+ print (f"" )
138+ print (f"{ GREEN } Step 2: expansion{ RESET } " )
139+ await self .expand_node (node )
143140
144141 while node is not None and node .is_terminal and not self .goal_finished :
145- logger . info (f"Depth limit node found at iteration { i + 1 } , reselecting..." )
142+ print (f"Depth limit node found at iteration { i + 1 } , reselecting..." )
146143 node = self .select_node (self .root_node )
147144 if node is not None :
148- self .expand_node (node )
145+ await self .expand_node (node )
149146
150147 if node is None :
151148 # all the nodes are terminal, stop the search
152- logger . info (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
149+ print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
153150 break
154151
155152 if self .goal_finished :
156- logger . info (f"{ RED } Goal finished, stopping search{ RESET } " )
153+ print (f"{ RED } Goal finished, stopping search{ RESET } " )
157154 break
158155
159156 print (f"{ GREEN } Tree:{ RESET } " )
160157 better_print (self .root_node )
161158 print (f"" )
162159
163160 # Step 3: Evaluation
164- logger . info (f"" )
165- logger . info (f"{ GREEN } Step 3: evaluation{ RESET } " )
166- self .evaluate_node (node )
161+ print (f"" )
162+ print (f"{ GREEN } Step 3: evaluation{ RESET } " )
163+ await self .evaluate_node (node )
167164
168165 print (f"{ GREEN } Tree:{ RESET } " )
169166 better_print (self .root_node )
170167 print (f"" )
171168
172169 # Step 4: Simulation
173- logger . info (f"{ GREEN } Step 4: simulation{ RESET } " )
170+ print (f"{ GREEN } Step 4: simulation{ RESET } " )
174171 # # Find the child with the highest value
175172 ## always = 1
176- reward , terminal_node = self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 )
173+ reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 )
177174 terminal_nodes .append (terminal_node )
178175
179176 if reward == 1 :
180177 return terminal_node
181178
182-
183- # print(f"{GREEN}Tree:{RESET}")
184- # better_print(self.root_node, selected_node=terminal_node)
185- # print(f"")
186-
187- # if self.goal_finished:
188- # logger.info(f"{RED}Goal finished, stopping search{RESET}")
189- # break
190-
191179 # Step 5: Backpropagation
192- logger . info (f"{ GREEN } Step 5: backpropagation{ RESET } " )
180+ print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
193181 self .backpropagate (terminal_node , reward )
194182 print (f"{ GREEN } Tree:{ RESET } " )
195183 better_print (self .root_node )
@@ -203,10 +191,10 @@ def lats_search(self) -> LATSNode:
203191 best_child = max (all_nodes_list , key = lambda x : (x .reward , x .depth ))
204192
205193 if best_child .reward == 1 :
206- logger . info ("Successful trajectory found" )
194+ print ("Successful trajectory found" )
207195 else :
208- logger . info ("Unsuccessful trajectory found" )
209- self .playwright_manager .close ()
196+ print ("Unsuccessful trajectory found" )
197+ await self .playwright_manager .close ()
210198
211199 return best_child if best_child is not None else self .root_node
212200
@@ -224,14 +212,14 @@ def select_node(self, node: LATSNode) -> Optional[LATSNode]:
224212 return None
225213 return node .get_best_leaf ()
226214
227- def expand_node (self , node : LATSNode ) -> None :
215+ async def expand_node (self , node : LATSNode ) -> None :
228216 """
229217 Expand a node by generating its children.
230218
231219 Args:
232220 node: Node to expand
233221 """
234- children = self .generate_children (node )
222+ children = await self .generate_children (node )
235223
236224 for child in children :
237225 node .add_child (child )
@@ -242,7 +230,7 @@ def expand_node(self, node: LATSNode) -> None:
242230
243231 node .check_terminal ()
244232
245- def evaluate_node (self , node : LATSNode ) -> None :
233+ async def evaluate_node (self , node : LATSNode ) -> None :
246234 """
247235 Evaluate a node using LLM scoring.
248236
@@ -253,23 +241,23 @@ def evaluate_node(self, node: LATSNode) -> None:
253241 float: Evaluation score
254242 """
255243 scores = []
256- logger . info (f"{ GREEN } -- total { len (node .children )} children to evaluate:{ RESET } " )
244+ print (f"{ GREEN } -- total { len (node .children )} children to evaluate:{ RESET } " )
257245 for i , child in enumerate (node .children ):
258- logger . info (f"{ GREEN } --- evaluating child { i + 1 } ...{ RESET } " )
246+ print (f"{ GREEN } --- evaluating child { i + 1 } ...{ RESET } " )
259247 if child .is_terminal :
260248 score = 0
261249 else :
262250 trajectory = child .get_trajectory ()
263251 prompt = create_llm_prompt (trajectory , self .goal )
264252 result = score_trajectory_with_openai (prompt , openai_client , self .config .evaluation_model , child .observation .image )
265- score = result ["score" ] / 10
253+ score = result ["overall_score" ]
266254 scores .append (score )
267255
268256 for child , score in zip (node .children , scores ):
269257 child .value = score
270258 child .reward = score
271259
272- def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 ) -> tuple [float , LATSNode ]:
260+ async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 ) -> tuple [float , LATSNode ]:
273261 """
274262 Perform a rollout simulation from a node.
275263
@@ -285,7 +273,7 @@ def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tup
285273 print_trajectory (node )
286274 print ("print the entire tree" )
287275 print_entire_tree (self .root_node )
288- return self .rollout (node , max_depth = max_depth )
276+ return await self .rollout (node , max_depth = max_depth )
289277
290278 def send_completion_request (self , plan , depth , node , trajectory = []):
291279 print ("print the trajectory" )
@@ -436,22 +424,22 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
436424 node .value = (node .value * (node .visits - 1 ) + value ) / node .visits
437425 node = node .parent
438426
async def _reset_browser(self) -> None:
    """Tear down the current browser session and start a fresh one.

    Closes the existing Playwright manager, builds a replacement from the
    agent config (headless flag, browser mode, storage state), and then
    navigates the new manager's page to ``self.starting_url``.
    """
    await self.playwright_manager.close()

    cfg = self.config
    launch_options = {
        "headless": cfg.headless,
        "mode": cfg.browser_mode,
        "storage_state": cfg.storage_state,
    }
    # NOTE: log_folder is not forwarded to setup_playwright here.
    self.playwright_manager = await setup_playwright(**launch_options)

    fresh_page = await self.playwright_manager.get_page()
    await fresh_page.goto(self.starting_url, wait_until="networkidle")
450438
451- def observe (self ) -> None :
452- page = self .playwright_manager .get_page ()
453- page_info = extract_page_info (page , self .config .fullpage , self .config .log_folder )
454- feature_text = observe_features (
439+ async def observe (self ) -> None :
440+ page = await self .playwright_manager .get_page ()
441+ page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
442+ feature_text = await observe_features (
455443 page_info ,
456444 features = self .config .features ,
457445 elements_filter = self .config .elements_filter ,
@@ -465,26 +453,47 @@ def observe(self) -> None:
465453 )
466454 return observation
467455
async def execute_action_trajectory(self, action_trajectory: list[dict]) -> bool:
    """Replay a sequence of actions in a freshly reset browser.

    Args:
        action_trajectory: Action dicts to run in order. Each one is read
            for the keys "natural_language_description", "action", "prob"
            and "element".

    Returns:
        bool: True when the trajectory is empty or every step reports
        success; False as soon as one step fails (the remaining actions
        are not attempted).
    """
    if not action_trajectory:
        # Nothing to replay, so the browser is left untouched.
        return True

    await self._reset_browser()
    print("taking action trajectory")
    for action_data in action_trajectory:
        print("action_data")
        print(action_data)

        # playwright_step_execution is handed a node rather than the raw
        # dict, so wrap the action in a throwaway, parentless LATSNode.
        temp_node = LATSNode(
            natural_language_description=action_data["natural_language_description"],
            action=action_data["action"],
            prob=action_data["prob"],
            element=action_data["element"],
            goal=self.goal,
            parent=None,
        )

        success = await playwright_step_execution(
            temp_node,
            self.goal,
            self.playwright_manager,
            is_replay=False,
            log_folder=self.config.log_folder,
        )

        if not success:
            return False
    return True
478487
479- def generate_candidate_actions (self , node : LATSNode ) -> list [dict ]:
488+ async def generate_candidate_actions (self , node : LATSNode ) -> list [dict ]:
480489 trajectory = node .get_trajectory ()
481490 action_trajectory = node .get_action_trajectory ()
482- self .execute_action_trajectory (action_trajectory )
483- observation = self .observe ()
491+ await self .execute_action_trajectory (action_trajectory )
492+ observation = await self .observe ()
484493 # only root node has no observation at this point
485494 if node .observation is None :
486495 node .observation = observation
487- actions = generate_actions_with_observation (
496+ actions = await generate_actions_with_observation (
488497 trajectory ,
489498 self .goal ,
490499 self .images ,
@@ -497,56 +506,56 @@ def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
497506 action_generation_model = self .config .action_generation_model ,
498507 )
499508
500- page = self .playwright_manager .get_page ()
509+ page = await self .playwright_manager .get_page ()
501510 valid_actions = []
502511 for action_data in actions :
503512 if action_data ["action" ] == "FINISH" :
504513 continue
505514
506- is_bid_action , element_data = locate_element_from_action (page , action_data ["action" ])
515+ is_bid_action , element_data = await locate_element_from_action (page , action_data ["action" ])
507516 if is_bid_action and not element_data :
508517 continue
509518
510519 action_data ['element' ] = element_data
511520 valid_actions .append (action_data )
512521 return valid_actions
513522
514- def generate_children (self , node : LATSNode ) -> list [LATSNode ]:
515- logger . info (f"{ GREEN } -- generating candidate actions...{ RESET } " )
523+ async def generate_children (self , node : LATSNode ) -> list [LATSNode ]:
524+ print (f"{ GREEN } -- generating candidate actions...{ RESET } " )
516525
517526 children = []
518527
519528 action_trajectory = node .get_action_trajectory ()
520- candidate_actions = self .generate_candidate_actions (node )
521- logger . info (f"{ GREEN } -- generated { len (candidate_actions )} actions{ RESET } " )
529+ candidate_actions = await self .generate_candidate_actions (node )
530+ print (f"{ GREEN } -- generated { len (candidate_actions )} actions{ RESET } " )
522531 for action_data in candidate_actions :
523- logger . info (f"{ GREEN } --- { action_data ['action' ]} { RESET } " )
524- logger . info (f"{ GREEN } --- { action_data ['natural_language_description' ]} { RESET } " )
532+ print (f"{ GREEN } --- { action_data ['action' ]} { RESET } " )
533+ print (f"{ GREEN } --- { action_data ['natural_language_description' ]} { RESET } " )
525534
526- logger . info (f"" )
527- logger . info (f"{ GREEN } -- executing candidate trajectories{ RESET } " )
535+ print (f"" )
536+ print (f"{ GREEN } -- executing candidate trajectories{ RESET } " )
528537 for i , action_data in enumerate (candidate_actions ):
529538
530539 candidate_action_trajectory = action_trajectory + [action_data ]
531- logger . info (f"{ GREEN } --- trajectory { i + 1 } :{ RESET } " )
540+ print (f"{ GREEN } --- trajectory { i + 1 } :{ RESET } " )
532541 for action in candidate_action_trajectory :
533- logger . info (f"{ GREEN } ---- { action ['action' ]} { RESET } " )
534- logger . info (f"{ GREEN } ---- { action ['natural_language_description' ]} { RESET } " )
535- executed_successfully = self .execute_action_trajectory (candidate_action_trajectory )
542+ print (f"{ GREEN } ---- { action ['action' ]} { RESET } " )
543+ print (f"{ GREEN } ---- { action ['natural_language_description' ]} { RESET } " )
544+ executed_successfully = await self .execute_action_trajectory (candidate_action_trajectory )
536545 if not executed_successfully :
537546 # not executed successfully, give up this candidate
538- logger . info (f"{ RED } --- failed to execute action trajectory{ RESET } " )
547+ print (f"{ RED } --- failed to execute action trajectory{ RESET } " )
539548 continue
540549
541- observation = self .observe ()
542- logger . info (f"{ GREEN } --- generate feedback...{ RESET } " )
543- feedback = generate_feedback_with_screenshot (
550+ observation = await self .observe ()
551+ print (f"{ GREEN } --- generate feedback...{ RESET } " )
552+ feedback = await generate_feedback_with_screenshot (
544553 self .goal ,
545554 action_data ["natural_language_description" ],
546555 observation .image ,
547556 model = self .config .feedback_model ,
548557 )
549- logger . info (f"feedback: is_done: { feedback .is_done } , explanation: { feedback .explanation } " )
558+ print (f"feedback: is_done: { feedback .is_done } , explanation: { feedback .explanation } " )
550559
551560 child = LATSNode (
552561 natural_language_description = action_data ["natural_language_description" ],
0 commit comments