@@ -275,7 +275,7 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
275275 print_entire_tree (self .root_node )
276276 return await self .rollout (node , max_depth = max_depth )
277277
278- def send_completion_request (self , plan , depth , node , trajectory = []):
278+ async def send_completion_request (self , plan , depth , node , trajectory = []):
279279 print ("print the trajectory" )
280280 print_trajectory (node )
281281 print ("print the entire tree" )
@@ -284,20 +284,20 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
284284 if depth >= self .config .max_depth :
285285 return trajectory , node
286286
287- context = self .playwright_manager .get_context ()
288- page = self .playwright_manager .get_page ()
287+ context = await self .playwright_manager .get_context ()
288+ page = await self .playwright_manager .get_page ()
289289 # Extract page information
290290 time .sleep (3 )
291- page_info = extract_page_info (page , fullpage = True , log_folder = self .config .log_folder )
292- updated_actions = extract_top_actions (
291+ page_info = await extract_page_info (page , fullpage = True , log_folder = self .config .log_folder )
292+ updated_actions = await extract_top_actions (
293293 trajectory , self .goal , self .images , page_info , self .action_set , openai_client ,
294294 features = ["axtree" ], elements_filter = "som" , branching_factor = self .config .branching_factor ,
295295 log_folder = self .config .log_folder , fullpage = True ,
296296 action_generation_model = self .config .action_generation_model ,
297297 action_grounding_model = self .config .action_grounding_model
298298 )
299299 next_action = updated_actions [0 ]
300- retry_count = self .config .retry_count if hasattr (self .config , 'retry_count' ) else 3 # Default retries if not set
300+ retry_count = self .config .retry_count if hasattr (self .config , 'retry_count' ) else 1 # Default retries if not set
301301
302302 for attempt in range (retry_count ):
303303 try :
@@ -308,13 +308,13 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
308308 if len (function_calls ) == 1 :
309309 for function_name , function_args in function_calls :
310310 extracted_number = parse_function_args (function_args )
311- element = locate_element (page , extracted_number )
311+ element = await locate_element (page , extracted_number )
312312 next_action ["element" ] = element
313313
314314 # Execute action
315- execute_action (next_action , self .action_set , page , context , self .goal , page_info ['interactive_elements' ],
315+ await execute_action (next_action , self .action_set , page , context , self .goal , page_info ['interactive_elements' ],
316316 self .config .log_folder )
317- feedback = capture_post_action_feedback (page , next_action , self .goal , self .config .log_folder )
317+ feedback = await capture_post_action_feedback (page , next_action , self .goal , self .config .log_folder )
318318 trajectory .append ({'action' : next_action ['action' ], 'feedback' : feedback })
319319 action_str = next_action ["action" ]
320320
@@ -328,7 +328,7 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
328328 messages .append ({"role" : "user" , "content" : 'action is: {}' .format (action )})
329329 messages .append ({"role" : "user" , "content" : 'action feedback is: {}' .format (feedback )})
330330
331- goal_finished = is_goal_finished (messages , openai_client )
331+ goal_finished = await is_goal_finished (messages , openai_client )
332332
333333 new_node = LATSNode (
334334 natural_language_description = next_action ["natural_language_description" ],
@@ -342,22 +342,22 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
342342 if goal_finished :
343343 return trajectory , new_node
344344
345- return self .send_completion_request (plan , depth + 1 , new_node , trajectory )
345+ return await self .send_completion_request (plan , depth + 1 , new_node , trajectory )
346346
347347 except Exception as e :
348348 print (f"Attempt { attempt + 1 } failed with error: { e } " )
349349 if attempt + 1 == retry_count :
350350 print ("Max retries reached. Skipping this step and retrying the whole request." )
351351 # Retry the entire request from the same state
352- return self .send_completion_request (plan , depth , node , trajectory )
352+ return await self .send_completion_request (plan , depth , node , trajectory )
353353
354354 # If all retries and retries of retries fail, return the current trajectory and node
355355 return trajectory , node
356356
357357
358- def rollout (self , node : LATSNode , max_depth : int = 2 )-> tuple [float , LATSNode ]:
358+ async def rollout (self , node : LATSNode , max_depth : int = 2 )-> tuple [float , LATSNode ]:
359359 # Reset browser state
360- self ._reset_browser ()
360+ await self ._reset_browser ()
361361 path = self .get_path_to_root (node )
362362
363363 print ("execute path" )
@@ -367,7 +367,7 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
367367 trajectory = []
368368
369369 for n in path [1 :]: # Skip root node
370- success = playwright_step_execution (
370+ success = await playwright_step_execution (
371371 n ,
372372 self .goal ,
373373 self .playwright_manager ,
@@ -377,7 +377,7 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
377377 if not success :
378378 return 0 , n
379379 if not n .feedback :
380- n .feedback = generate_feedback (
380+ n .feedback = await generate_feedback (
381381 self .goal ,
382382 n .natural_language_description ,
383383 self .playwright_manager ,
@@ -389,14 +389,14 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
389389 ## call the prompt agent
390390 print ("current depth: " , len (path ) - 1 )
391391 print ("max depth: " , self .config .max_depth )
392- trajectory , node = self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory )
392+ trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory )
393393 print ("print the trajectory" )
394394 print_trajectory (node )
395395 print ("print the entire tree" )
396396 print_entire_tree (self .root_node )
397397
398- page = self .playwright_manager .get_page ()
399- page_info = extract_page_info (page , self .config .fullpage , self .config .log_folder )
398+ page = await self .playwright_manager .get_page ()
399+ page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
400400
401401 messages = [{"role" : "user" , "content" : f"Action is: { n .action } " } for n in path [1 :]]
402402 goal_finished , confidence_score = goal_finished_evaluator (
@@ -467,7 +467,7 @@ async def execute_action_trajectory(self, action_trajectory: list[dict]) -> None
467467 temp_node = LATSNode (
468468 natural_language_description = action_data ["natural_language_description" ],
469469 action = action_data ["action" ],
470- prob = action_data [ "prob" ] ,
470+ prob = 0 ,
471471 element = action_data ["element" ],
472472 goal = self .goal ,
473473 parent = None # No parent needed for temporary node
0 commit comments