Skip to content

Commit 6be89b8

Browse files
committed
add websocket
1 parent 8b38805 commit 6be89b8

File tree

4 files changed

+383
-27
lines changed

4 files changed

+383
-27
lines changed

visual-tree-search-backend/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,12 @@ python run_demo_treesearch_async.py \
8282
--goal "search running shoes, click on the first result" \
8383
--iterations 3 \
8484
--max_depth 3
85+
```
86+
87+
## 7. Add LATS agent
88+
* test run_demo_treesearch_async.py
89+
* test web socket
90+
```
91+
uvicorn app.main:app --host 0.0.0.0 --port 3000
92+
python test/test-tree-search-ws-lats.py
8593
```

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py

Lines changed: 213 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33
import time
44
from typing import Any, Optional, Tuple, List
5-
5+
import os
66
from openai import OpenAI
7+
from datetime import datetime
8+
import aiohttp
9+
from dotenv import load_dotenv
10+
load_dotenv()
711

812
from .lats_node import LATSNode, Observation
913
from ...core_async.config import AgentConfig
@@ -12,7 +16,6 @@
1216
from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright
1317
from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
1418
from .trajectory_score import create_llm_prompt, score_trajectory_with_openai
15-
# from ...replay import locate_element_from_action, step_execution
1619
from ...replay_async import generate_feedback, playwright_step_execution, locate_element_from_action
1720
from ...webagent_utils_async.browser_env.observation import extract_page_info, observe_features
1821
from ...webagent_utils_async.action.prompt_functions import generate_actions_with_observation
@@ -22,7 +25,6 @@
2225

2326
from ...webagent_utils_async.utils.utils import parse_function_args, locate_element
2427
from ...evaluation_async.evaluators import goal_finished_evaluator
25-
# from ...replay import playwright_step_execution, generate_feedback
2628
from ...webagent_utils_async.action.prompt_functions import extract_top_actions
2729
from ...webagent_utils_async.browser_env.observation import extract_page_info
2830
from .lats_node import LATSNode
@@ -91,22 +93,47 @@ def __init__(
9193
)
9294
self.goal_finished = False
9395
self.result_node = None
96+
self.reset_url = os.environ["ACCOUNT_RESET_URL"]
9497

95-
async def run(self) -> list[LATSNode]:
98+
async def run(self, websocket=None) -> list[LATSNode]:
9699
"""
97100
Run the LATS search and return the best path found.
98101
102+
Args:
103+
websocket: Optional WebSocket connection for sending updates
104+
99105
Returns:
100106
list[LATSNode]: Best path from root to terminal node
101107
"""
102-
best_node = await self.lats_search()
108+
if websocket:
109+
await websocket.send_json({
110+
"type": "search_status",
111+
"status": "started",
112+
"message": "Starting LATS search",
113+
"timestamp": datetime.utcnow().isoformat()
114+
})
115+
116+
best_node = await self.lats_search(websocket)
103117
print_trajectory(best_node)
118+
119+
if websocket:
120+
await websocket.send_json({
121+
"type": "search_complete",
122+
"status": "success" if best_node.reward == 1 else "partial_success",
123+
"score": best_node.reward,
124+
"path": [{"id": id(node), "action": node.action} for node in best_node.get_trajectory()],
125+
"timestamp": datetime.utcnow().isoformat()
126+
})
127+
104128
return best_node.get_trajectory()
105129

106-
async def lats_search(self) -> LATSNode:
130+
async def lats_search(self, websocket=None) -> LATSNode:
107131
"""
108132
Perform the main LATS search algorithm.
109133
134+
Args:
135+
websocket: Optional WebSocket connection for sending updates
136+
110137
Returns:
111138
LATSNode: Best terminal node found
112139
"""
@@ -116,13 +143,26 @@ async def lats_search(self) -> LATSNode:
116143
terminal_nodes = []
117144

118145
for i in range(self.config.iterations):
146+
if websocket:
147+
await websocket.send_json({
148+
"type": "iteration_start",
149+
"iteration": i + 1,
150+
"timestamp": datetime.utcnow().isoformat()
151+
})
152+
119153
print(f"")
120154
print(f"")
121155
print(f"Iteration {i + 1}...")
122156

123-
# Step 1: Selection
124-
print(f"")
125-
print(f"{GREEN}Step 1: selection{RESET}")
157+
# Step 1: Selection with websocket update
158+
if websocket:
159+
await websocket.send_json({
160+
"type": "step_start",
161+
"step": "selection",
162+
"iteration": i + 1,
163+
"timestamp": datetime.utcnow().isoformat()
164+
})
165+
126166
node = self.select_node(self.root_node)
127167

128168
if node is None:
@@ -133,16 +173,22 @@ async def lats_search(self) -> LATSNode:
133173
better_print(node=self.root_node, selected_node=node)
134174
print(f"")
135175

136-
# Step 2: Expansion
137-
print(f"")
138-
print(f"{GREEN}Step 2: expansion{RESET}")
139-
await self.expand_node(node)
176+
# Step 2: Expansion with websocket update
177+
if websocket:
178+
await websocket.send_json({
179+
"type": "step_start",
180+
"step": "expansion",
181+
"iteration": i + 1,
182+
"timestamp": datetime.utcnow().isoformat()
183+
})
184+
185+
await self.expand_node(node, websocket)
140186

141187
while node is not None and node.is_terminal and not self.goal_finished:
142188
print(f"Depth limit node found at iteration {i + 1}, reselecting...")
143189
node = self.select_node(self.root_node)
144190
if node is not None:
145-
await self.expand_node(node)
191+
await self.expand_node(node, websocket)
146192

147193
if node is None:
148194
# all the nodes are terminal, stop the search
@@ -183,6 +229,15 @@ async def lats_search(self) -> LATSNode:
183229
better_print(self.root_node)
184230
print(f"")
185231

232+
# Send tree update after each iteration
233+
if websocket:
234+
tree_data = self._get_tree_data()
235+
await websocket.send_json({
236+
"type": "tree_update",
237+
"tree": tree_data,
238+
"timestamp": datetime.utcnow().isoformat()
239+
})
240+
186241
# Find best node
187242
all_nodes_list = collect_all_nodes(self.root_node)
188243
all_nodes_list.extend(terminal_nodes)
@@ -212,20 +267,43 @@ def select_node(self, node: LATSNode) -> Optional[LATSNode]:
212267
return None
213268
return node.get_best_leaf()
214269

215-
async def expand_node(self, node: LATSNode) -> None:
270+
async def expand_node(self, node: LATSNode, websocket=None) -> None:
216271
"""
217272
Expand a node by generating its children.
218273
219274
Args:
220275
node: Node to expand
276+
websocket: Optional WebSocket connection for sending updates
221277
"""
222-
children = await self.generate_children(node)
278+
if websocket:
279+
await websocket.send_json({
280+
"type": "node_expanding",
281+
"node_id": id(node),
282+
"timestamp": datetime.utcnow().isoformat()
283+
})
284+
285+
children = await self.generate_children(node, websocket)
223286

224287
for child in children:
225288
node.add_child(child)
289+
if websocket:
290+
await websocket.send_json({
291+
"type": "node_created",
292+
"node_id": id(child),
293+
"parent_id": id(node),
294+
"action": child.action,
295+
"description": child.natural_language_description,
296+
"timestamp": datetime.utcnow().isoformat()
297+
})
226298

227299
if children and children[0].goal_finish_feedback.is_done:
228300
self.set_goal_finished(children[0])
301+
if websocket:
302+
await websocket.send_json({
303+
"type": "goal_finished",
304+
"node_id": id(children[0]),
305+
"timestamp": datetime.utcnow().isoformat()
306+
})
229307
return
230308

231309
node.check_terminal()
@@ -424,17 +502,104 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
424502
node.value = (node.value * (node.visits - 1) + value) / node.visits
425503
node = node.parent
426504

427-
async def _reset_browser(self) -> None:
428-
"""Reset the browser to initial state."""
505+
async def _reset_browser(self, websocket=None) -> Optional[str]:
506+
"""Reset the browser to initial state and return the live browser URL if available."""
429507
await self.playwright_manager.close()
430-
self.playwright_manager = await setup_playwright(
431-
headless=self.config.headless,
432-
mode=self.config.browser_mode,
433-
storage_state=self.config.storage_state,
434-
# log_folder=self.config.log_folder,
435-
)
436-
page = await self.playwright_manager.get_page()
437-
await page.goto(self.starting_url, wait_until="networkidle")
508+
509+
## reset account using api-based account reset
510+
if self.config.account_reset:
511+
if websocket:
512+
await websocket.send_json({
513+
"type": "account_reset",
514+
"status": "started",
515+
"timestamp": datetime.utcnow().isoformat()
516+
})
517+
518+
try:
519+
# Use aiohttp instead of curl
520+
async with aiohttp.ClientSession() as session:
521+
headers = {'Connection': 'close'} # Similar to curl -N
522+
async with session.get(self.reset_url, headers=headers) as response:
523+
if response.status == 200:
524+
data = await response.json()
525+
print(f"Account reset successful: {data}")
526+
if websocket:
527+
await websocket.send_json({
528+
"type": "account_reset",
529+
"status": "success",
530+
"data": data,
531+
"timestamp": datetime.utcnow().isoformat()
532+
})
533+
else:
534+
error_msg = f"Account reset failed with status {response.status}"
535+
print(error_msg)
536+
if websocket:
537+
await websocket.send_json({
538+
"type": "account_reset",
539+
"status": "failed",
540+
"reason": error_msg,
541+
"timestamp": datetime.utcnow().isoformat()
542+
})
543+
544+
except Exception as e:
545+
print(f"Error during account reset: {e}")
546+
if websocket:
547+
await websocket.send_json({
548+
"type": "account_reset",
549+
"status": "failed",
550+
"reason": str(e),
551+
"timestamp": datetime.utcnow().isoformat()
552+
})
553+
554+
try:
555+
# Create new playwright manager
556+
self.playwright_manager = await setup_playwright(
557+
storage_state=self.config.storage_state,
558+
headless=self.config.headless,
559+
mode=self.config.browser_mode
560+
)
561+
page = await self.playwright_manager.get_page()
562+
live_browser_url = None
563+
if self.config.browser_mode == "browserbase":
564+
live_browser_url = await self.playwright_manager.get_live_browser_url()
565+
session_id = await self.playwright_manager.get_session_id()
566+
else:
567+
session_id = None
568+
live_browser_url = None
569+
await page.goto(self.starting_url, wait_until="networkidle")
570+
571+
# Send success message if websocket is provided
572+
if websocket:
573+
if self.config.storage_state:
574+
await websocket.send_json({
575+
"type": "browser_setup",
576+
"status": "success",
577+
"message": f"Browser successfully initialized with storage state file: {self.config.storage_state}",
578+
"live_browser_url": live_browser_url,
579+
"session_id": session_id,
580+
"timestamp": datetime.utcnow().isoformat()
581+
})
582+
else:
583+
await websocket.send_json({
584+
"type": "browser_setup",
585+
"status": "success",
586+
"message": "Browser successfully initialized",
587+
"live_browser_url": live_browser_url,
588+
"session_id": session_id,
589+
"timestamp": datetime.utcnow().isoformat()
590+
})
591+
592+
return live_browser_url, session_id
593+
except Exception as e:
594+
print(f"Error setting up browser: {e}")
595+
if websocket:
596+
await websocket.send_json({
597+
"type": "browser_setup",
598+
"status": "failed",
599+
"reason": str(e),
600+
"timestamp": datetime.utcnow().isoformat()
601+
})
602+
return None, None
438603

439604
async def observe(self) -> None:
440605
page = await self.playwright_manager.get_page()
@@ -520,7 +685,7 @@ async def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
520685
valid_actions.append(action_data)
521686
return valid_actions
522687

523-
async def generate_children(self, node: LATSNode) -> list[LATSNode]:
688+
async def generate_children(self, node: LATSNode, websocket=None) -> list[LATSNode]:
524689
print(f"{GREEN}-- generating candidate actions...{RESET}")
525690

526691
children = []
@@ -588,3 +753,24 @@ def get_path_to_root(self, node: LATSNode) -> List[LATSNode]:
588753
path.append(current)
589754
current = current.parent
590755
return list(reversed(path))
756+
757+
def _get_tree_data(self):
758+
"""Get tree data in a format suitable for visualization"""
759+
nodes = collect_all_nodes(self.root_node)
760+
tree_data = []
761+
762+
for node in nodes:
763+
node_data = {
764+
"id": id(node),
765+
"parent_id": id(node.parent) if node.parent else None,
766+
"action": node.action if node.action else "ROOT",
767+
"description": node.natural_language_description,
768+
"depth": node.depth,
769+
"is_terminal": node.is_terminal,
770+
"value": node.value,
771+
"visits": node.visits,
772+
"reward": node.reward
773+
}
774+
tree_data.append(node_data)
775+
776+
return tree_data

visual-tree-search-backend/app/api/routes/tree_search_websocket.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
126126
elif search_algorithm.lower() == "dfs":
127127
# Use the agent's built-in WebSocket-enabled DFS method
128128
await agent.dfs_with_websocket(websocket)
129+
elif search_algorithm.lower() == "lats":
130+
await agent.lats_search(websocket)
129131
else:
130132
await websocket.send_json({
131133
"type": "error",

0 commit comments

Comments
 (0)