22
33import time
44from typing import Any , Optional , Tuple , List
5-
5+ import os
66from openai import OpenAI
7+ from datetime import datetime
8+ import aiohttp
9+ from dotenv import load_dotenv
10+ load_dotenv ()
711
812from .lats_node import LATSNode , Observation
913from ...core_async .config import AgentConfig
1216from ...webagent_utils_async .utils .playwright_manager import AsyncPlaywrightManager , setup_playwright
1317from .tree_vis import RED , better_print , print_trajectory , collect_all_nodes , GREEN , RESET , print_entire_tree
1418from .trajectory_score import create_llm_prompt , score_trajectory_with_openai
15- # from ...replay import locate_element_from_action, step_execution
1619from ...replay_async import generate_feedback , playwright_step_execution , locate_element_from_action
1720from ...webagent_utils_async .browser_env .observation import extract_page_info , observe_features
1821from ...webagent_utils_async .action .prompt_functions import generate_actions_with_observation
2225
2326from ...webagent_utils_async .utils .utils import parse_function_args , locate_element
2427from ...evaluation_async .evaluators import goal_finished_evaluator
25- # from ...replay import playwright_step_execution, generate_feedback
2628from ...webagent_utils_async .action .prompt_functions import extract_top_actions
2729from ...webagent_utils_async .browser_env .observation import extract_page_info
2830from .lats_node import LATSNode
@@ -91,22 +93,47 @@ def __init__(
9193 )
9294 self .goal_finished = False
9395 self .result_node = None
96+ self .reset_url = os .environ ["ACCOUNT_RESET_URL" ]
9497
95- async def run (self ) -> list [LATSNode ]:
98+ async def run (self , websocket = None ) -> list [LATSNode ]:
9699 """
97100 Run the LATS search and return the best path found.
98101
102+ Args:
103+ websocket: Optional WebSocket connection for sending updates
104+
99105 Returns:
100106 list[LATSNode]: Best path from root to terminal node
101107 """
102- best_node = await self .lats_search ()
108+ if websocket :
109+ await websocket .send_json ({
110+ "type" : "search_status" ,
111+ "status" : "started" ,
112+ "message" : "Starting LATS search" ,
113+ "timestamp" : datetime .utcnow ().isoformat ()
114+ })
115+
116+ best_node = await self .lats_search (websocket )
103117 print_trajectory (best_node )
118+
119+ if websocket :
120+ await websocket .send_json ({
121+ "type" : "search_complete" ,
122+ "status" : "success" if best_node .reward == 1 else "partial_success" ,
123+ "score" : best_node .reward ,
124+ "path" : [{"id" : id (node ), "action" : node .action } for node in best_node .get_trajectory ()],
125+ "timestamp" : datetime .utcnow ().isoformat ()
126+ })
127+
104128 return best_node .get_trajectory ()
105129
106- async def lats_search (self ) -> LATSNode :
130+ async def lats_search (self , websocket = None ) -> LATSNode :
107131 """
108132 Perform the main LATS search algorithm.
109133
134+ Args:
135+ websocket: Optional WebSocket connection for sending updates
136+
110137 Returns:
111138 LATSNode: Best terminal node found
112139 """
@@ -116,13 +143,26 @@ async def lats_search(self) -> LATSNode:
116143 terminal_nodes = []
117144
118145 for i in range (self .config .iterations ):
146+ if websocket :
147+ await websocket .send_json ({
148+ "type" : "iteration_start" ,
149+ "iteration" : i + 1 ,
150+ "timestamp" : datetime .utcnow ().isoformat ()
151+ })
152+
119153 print (f"" )
120154 print (f"" )
121155 print (f"Iteration { i + 1 } ..." )
122156
123- # Step 1: Selection
124- print (f"" )
125- print (f"{ GREEN } Step 1: selection{ RESET } " )
157+ # Step 1: Selection with websocket update
158+ if websocket :
159+ await websocket .send_json ({
160+ "type" : "step_start" ,
161+ "step" : "selection" ,
162+ "iteration" : i + 1 ,
163+ "timestamp" : datetime .utcnow ().isoformat ()
164+ })
165+
126166 node = self .select_node (self .root_node )
127167
128168 if node is None :
@@ -133,16 +173,22 @@ async def lats_search(self) -> LATSNode:
133173 better_print (node = self .root_node , selected_node = node )
134174 print (f"" )
135175
136- # Step 2: Expansion
137- print (f"" )
138- print (f"{ GREEN } Step 2: expansion{ RESET } " )
139- await self .expand_node (node )
176+ # Step 2: Expansion with websocket update
177+ if websocket :
178+ await websocket .send_json ({
179+ "type" : "step_start" ,
180+ "step" : "expansion" ,
181+ "iteration" : i + 1 ,
182+ "timestamp" : datetime .utcnow ().isoformat ()
183+ })
184+
185+ await self .expand_node (node , websocket )
140186
141187 while node is not None and node .is_terminal and not self .goal_finished :
142188 print (f"Depth limit node found at iteration { i + 1 } , reselecting..." )
143189 node = self .select_node (self .root_node )
144190 if node is not None :
145- await self .expand_node (node )
191+ await self .expand_node (node , websocket )
146192
147193 if node is None :
148194 # all the nodes are terminal, stop the search
@@ -183,6 +229,15 @@ async def lats_search(self) -> LATSNode:
183229 better_print (self .root_node )
184230 print (f"" )
185231
232+ # Send tree update after each iteration
233+ if websocket :
234+ tree_data = self ._get_tree_data ()
235+ await websocket .send_json ({
236+ "type" : "tree_update" ,
237+ "tree" : tree_data ,
238+ "timestamp" : datetime .utcnow ().isoformat ()
239+ })
240+
186241 # Find best node
187242 all_nodes_list = collect_all_nodes (self .root_node )
188243 all_nodes_list .extend (terminal_nodes )
@@ -212,20 +267,43 @@ def select_node(self, node: LATSNode) -> Optional[LATSNode]:
212267 return None
213268 return node .get_best_leaf ()
214269
215- async def expand_node (self , node : LATSNode ) -> None :
270+ async def expand_node (self , node : LATSNode , websocket = None ) -> None :
216271 """
217272 Expand a node by generating its children.
218273
219274 Args:
220275 node: Node to expand
276+ websocket: Optional WebSocket connection for sending updates
221277 """
222- children = await self .generate_children (node )
278+ if websocket :
279+ await websocket .send_json ({
280+ "type" : "node_expanding" ,
281+ "node_id" : id (node ),
282+ "timestamp" : datetime .utcnow ().isoformat ()
283+ })
284+
285+ children = await self .generate_children (node , websocket )
223286
224287 for child in children :
225288 node .add_child (child )
289+ if websocket :
290+ await websocket .send_json ({
291+ "type" : "node_created" ,
292+ "node_id" : id (child ),
293+ "parent_id" : id (node ),
294+ "action" : child .action ,
295+ "description" : child .natural_language_description ,
296+ "timestamp" : datetime .utcnow ().isoformat ()
297+ })
226298
227299 if children and children [0 ].goal_finish_feedback .is_done :
228300 self .set_goal_finished (children [0 ])
301+ if websocket :
302+ await websocket .send_json ({
303+ "type" : "goal_finished" ,
304+ "node_id" : id (children [0 ]),
305+ "timestamp" : datetime .utcnow ().isoformat ()
306+ })
229307 return
230308
231309 node .check_terminal ()
@@ -424,17 +502,104 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
424502 node .value = (node .value * (node .visits - 1 ) + value ) / node .visits
425503 node = node .parent
426504
427- async def _reset_browser (self ) -> None :
428- """Reset the browser to initial state."""
505+ async def _reset_browser (self , websocket = None ) -> Optional [ str ] :
506+ """Reset the browser to initial state and return the live browser URL if available ."""
429507 await self .playwright_manager .close ()
430- self .playwright_manager = await setup_playwright (
431- headless = self .config .headless ,
432- mode = self .config .browser_mode ,
433- storage_state = self .config .storage_state ,
434- # log_folder=self.config.log_folder,
435- )
436- page = await self .playwright_manager .get_page ()
437- await page .goto (self .starting_url , wait_until = "networkidle" )
508+
509+ ## reset account using api-based account reset
510+ if self .config .account_reset :
511+ if websocket :
512+ await websocket .send_json ({
513+ "type" : "account_reset" ,
514+ "status" : "started" ,
515+ "timestamp" : datetime .utcnow ().isoformat ()
516+ })
517+
518+ try :
519+ # Use aiohttp instead of curl
520+ async with aiohttp .ClientSession () as session :
521+ headers = {'Connection' : 'close' } # Similar to curl -N
522+ async with session .get (self .reset_url , headers = headers ) as response :
523+ if response .status == 200 :
524+ data = await response .json ()
525+ print (f"Account reset successful: { data } " )
526+ if websocket :
527+ await websocket .send_json ({
528+ "type" : "account_reset" ,
529+ "status" : "success" ,
530+ "data" : data ,
531+ "timestamp" : datetime .utcnow ().isoformat ()
532+ })
533+ else :
534+ error_msg = f"Account reset failed with status { response .status } "
535+ print (error_msg )
536+ if websocket :
537+ await websocket .send_json ({
538+ "type" : "account_reset" ,
539+ "status" : "failed" ,
540+ "reason" : error_msg ,
541+ "timestamp" : datetime .utcnow ().isoformat ()
542+ })
543+
544+ except Exception as e :
545+ print (f"Error during account reset: { e } " )
546+ if websocket :
547+ await websocket .send_json ({
548+ "type" : "account_reset" ,
549+ "status" : "failed" ,
550+ "reason" : str (e ),
551+ "timestamp" : datetime .utcnow ().isoformat ()
552+ })
553+
554+ try :
555+ # Create new playwright manager
556+ self .playwright_manager = await setup_playwright (
557+ storage_state = self .config .storage_state ,
558+ headless = self .config .headless ,
559+ mode = self .config .browser_mode
560+ )
561+ page = await self .playwright_manager .get_page ()
562+ live_browser_url = None
563+ if self .config .browser_mode == "browserbase" :
564+ live_browser_url = await self .playwright_manager .get_live_browser_url ()
565+ session_id = await self .playwright_manager .get_session_id ()
566+ else :
567+ session_id = None
568+ live_browser_url = None
569+ await page .goto (self .starting_url , wait_until = "networkidle" )
570+
571+ # Send success message if websocket is provided
572+ if websocket :
573+ if self .config .storage_state :
574+ await websocket .send_json ({
575+ "type" : "browser_setup" ,
576+ "status" : "success" ,
577+ "message" : f"Browser successfully initialized with storage state file: { self .config .storage_state } " ,
578+ "live_browser_url" : live_browser_url ,
579+ "session_id" : session_id ,
580+ "timestamp" : datetime .utcnow ().isoformat ()
581+ })
582+ else :
583+ await websocket .send_json ({
584+ "type" : "browser_setup" ,
585+ "status" : "success" ,
586+ "message" : "Browser successfully initialized" ,
587+ "live_browser_url" : live_browser_url ,
588+ "session_id" : session_id ,
589+ "timestamp" : datetime .utcnow ().isoformat ()
590+ })
591+
592+ return live_browser_url , session_id
593+ except Exception as e :
594+ print (f"Error setting up browser: { e } " )
595+ if websocket :
596+ await websocket .send_json ({
597+ "type" : "browser_setup" ,
598+ "status" : "failed" ,
599+ "reason" : str (e ),
600+ "timestamp" : datetime .utcnow ().isoformat ()
601+ })
602+ return None , None
438603
439604 async def observe (self ) -> None :
440605 page = await self .playwright_manager .get_page ()
@@ -520,7 +685,7 @@ async def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
520685 valid_actions .append (action_data )
521686 return valid_actions
522687
523- async def generate_children (self , node : LATSNode ) -> list [LATSNode ]:
688+ async def generate_children (self , node : LATSNode , websocket = None ) -> list [LATSNode ]:
524689 print (f"{ GREEN } -- generating candidate actions...{ RESET } " )
525690
526691 children = []
@@ -588,3 +753,24 @@ def get_path_to_root(self, node: LATSNode) -> List[LATSNode]:
588753 path .append (current )
589754 current = current .parent
590755 return list (reversed (path ))
756+
757+ def _get_tree_data (self ):
758+ """Get tree data in a format suitable for visualization"""
759+ nodes = collect_all_nodes (self .root_node )
760+ tree_data = []
761+
762+ for node in nodes :
763+ node_data = {
764+ "id" : id (node ),
765+ "parent_id" : id (node .parent ) if node .parent else None ,
766+ "action" : node .action if node .action else "ROOT" ,
767+ "description" : node .natural_language_description ,
768+ "depth" : node .depth ,
769+ "is_terminal" : node .is_terminal ,
770+ "value" : node .value ,
771+ "visits" : node .visits ,
772+ "reward" : node .reward
773+ }
774+ tree_data .append (node_data )
775+
776+ return tree_data
0 commit comments