Skip to content

Commit bb63603

Browse files
committed
step 1,2,3 done
1 parent 9e1e75a commit bb63603

File tree

2 files changed

+94
-85
lines changed

2 files changed

+94
-85
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py

Lines changed: 93 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Language-based Action Tree Search (LATS) Agent implementation."""
22

3-
import logging
43
import time
54
from typing import Any, Optional, Tuple, List
65

@@ -34,7 +33,6 @@
3433
from ...webagent_utils_async.browser_env.observation import extract_page_info
3534
from ...webagent_utils_async.evaluation.feedback import capture_post_action_feedback
3635

37-
logger = logging.getLogger(__name__)
3836
openai_client = OpenAI()
3937

4038
class LATSAgent:
@@ -101,95 +99,85 @@ async def run(self) -> list[LATSNode]:
10199
Returns:
102100
list[LATSNode]: Best path from root to terminal node
103101
"""
104-
pass
105-
# best_node = self.lats_search()
106-
# print_trajectory(best_node)
107-
# return best_node.get_trajectory()
102+
best_node = await self.lats_search()
103+
print_trajectory(best_node)
104+
return best_node.get_trajectory()
108105

109-
def lats_search(self) -> LATSNode:
106+
async def lats_search(self) -> LATSNode:
110107
"""
111108
Perform the main LATS search algorithm.
112109
113110
Returns:
114111
LATSNode: Best terminal node found
115112
"""
116-
logger.info(f"")
117-
logger.info(f"{GREEN}START SEARCH{RESET}")
113+
print(f"")
114+
print(f"{GREEN}START SEARCH{RESET}")
118115

119116
terminal_nodes = []
120117

121118
for i in range(self.config.iterations):
122-
logger.info(f"")
123-
logger.info(f"")
124-
logger.info(f"Iteration {i + 1}...")
119+
print(f"")
120+
print(f"")
121+
print(f"Iteration {i + 1}...")
125122

126123
# Step 1: Selection
127-
logger.info(f"")
128-
logger.info(f"{GREEN}Step 1: selection{RESET}")
124+
print(f"")
125+
print(f"{GREEN}Step 1: selection{RESET}")
129126
node = self.select_node(self.root_node)
130127

131128
if node is None:
132-
logger.info("All paths lead to terminal nodes with reward 0. Ending search.")
129+
print("All paths lead to terminal nodes with reward 0. Ending search.")
133130
break
134131

135132
print(f"{GREEN}Tree:{RESET}")
136133
better_print(node=self.root_node, selected_node=node)
137134
print(f"")
138135

139136
# Step 2: Expansion
140-
logger.info(f"")
141-
logger.info(f"{GREEN}Step 2: expansion{RESET}")
142-
self.expand_node(node)
137+
print(f"")
138+
print(f"{GREEN}Step 2: expansion{RESET}")
139+
await self.expand_node(node)
143140

144141
while node is not None and node.is_terminal and not self.goal_finished:
145-
logger.info(f"Depth limit node found at iteration {i + 1}, reselecting...")
142+
print(f"Depth limit node found at iteration {i + 1}, reselecting...")
146143
node = self.select_node(self.root_node)
147144
if node is not None:
148-
self.expand_node(node)
145+
await self.expand_node(node)
149146

150147
if node is None:
151148
# all the nodes are terminal, stop the search
152-
logger.info(f"{RED}All nodes are terminal, stopping search{RESET}")
149+
print(f"{RED}All nodes are terminal, stopping search{RESET}")
153150
break
154151

155152
if self.goal_finished:
156-
logger.info(f"{RED}Goal finished, stopping search{RESET}")
153+
print(f"{RED}Goal finished, stopping search{RESET}")
157154
break
158155

159156
print(f"{GREEN}Tree:{RESET}")
160157
better_print(self.root_node)
161158
print(f"")
162159

163160
# Step 3: Evaluation
164-
logger.info(f"")
165-
logger.info(f"{GREEN}Step 3: evaluation{RESET}")
166-
self.evaluate_node(node)
161+
print(f"")
162+
print(f"{GREEN}Step 3: evaluation{RESET}")
163+
await self.evaluate_node(node)
167164

168165
print(f"{GREEN}Tree:{RESET}")
169166
better_print(self.root_node)
170167
print(f"")
171168

172169
# Step 4: Simulation
173-
logger.info(f"{GREEN}Step 4: simulation{RESET}")
170+
print(f"{GREEN}Step 4: simulation{RESET}")
174171
# # Find the child with the highest value
175172
## always = 1
176-
reward, terminal_node = self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1)
173+
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1)
177174
terminal_nodes.append(terminal_node)
178175

179176
if reward == 1:
180177
return terminal_node
181178

182-
183-
# print(f"{GREEN}Tree:{RESET}")
184-
# better_print(self.root_node, selected_node=terminal_node)
185-
# print(f"")
186-
187-
# if self.goal_finished:
188-
# logger.info(f"{RED}Goal finished, stopping search{RESET}")
189-
# break
190-
191179
# Step 5: Backpropagation
192-
logger.info(f"{GREEN}Step 5: backpropagation{RESET}")
180+
print(f"{GREEN}Step 5: backpropagation{RESET}")
193181
self.backpropagate(terminal_node, reward)
194182
print(f"{GREEN}Tree:{RESET}")
195183
better_print(self.root_node)
@@ -203,10 +191,10 @@ def lats_search(self) -> LATSNode:
203191
best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth))
204192

205193
if best_child.reward == 1:
206-
logger.info("Successful trajectory found")
194+
print("Successful trajectory found")
207195
else:
208-
logger.info("Unsuccessful trajectory found")
209-
self.playwright_manager.close()
196+
print("Unsuccessful trajectory found")
197+
await self.playwright_manager.close()
210198

211199
return best_child if best_child is not None else self.root_node
212200

@@ -224,14 +212,14 @@ def select_node(self, node: LATSNode) -> Optional[LATSNode]:
224212
return None
225213
return node.get_best_leaf()
226214

227-
def expand_node(self, node: LATSNode) -> None:
215+
async def expand_node(self, node: LATSNode) -> None:
228216
"""
229217
Expand a node by generating its children.
230218
231219
Args:
232220
node: Node to expand
233221
"""
234-
children = self.generate_children(node)
222+
children = await self.generate_children(node)
235223

236224
for child in children:
237225
node.add_child(child)
@@ -242,7 +230,7 @@ def expand_node(self, node: LATSNode) -> None:
242230

243231
node.check_terminal()
244232

245-
def evaluate_node(self, node: LATSNode) -> None:
233+
async def evaluate_node(self, node: LATSNode) -> None:
246234
"""
247235
Evaluate a node using LLM scoring.
248236
@@ -253,23 +241,23 @@ def evaluate_node(self, node: LATSNode) -> None:
253241
float: Evaluation score
254242
"""
255243
scores = []
256-
logger.info(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}")
244+
print(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}")
257245
for i, child in enumerate(node.children):
258-
logger.info(f"{GREEN}--- evaluating child {i+1}...{RESET}")
246+
print(f"{GREEN}--- evaluating child {i+1}...{RESET}")
259247
if child.is_terminal:
260248
score = 0
261249
else:
262250
trajectory = child.get_trajectory()
263251
prompt = create_llm_prompt(trajectory, self.goal)
264252
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model, child.observation.image)
265-
score = result["score"]/10
253+
score = result["overall_score"]
266254
scores.append(score)
267255

268256
for child, score in zip(node.children, scores):
269257
child.value = score
270258
child.reward = score
271259

272-
def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]:
260+
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]:
273261
"""
274262
Perform a rollout simulation from a node.
275263
@@ -285,7 +273,7 @@ def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tup
285273
print_trajectory(node)
286274
print("print the entire tree")
287275
print_entire_tree(self.root_node)
288-
return self.rollout(node, max_depth=max_depth)
276+
return await self.rollout(node, max_depth=max_depth)
289277

290278
def send_completion_request(self, plan, depth, node, trajectory=[]):
291279
print("print the trajectory")
@@ -436,22 +424,22 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
436424
node.value = (node.value * (node.visits - 1) + value) / node.visits
437425
node = node.parent
438426

439-
def _reset_browser(self) -> None:
427+
async def _reset_browser(self) -> None:
440428
"""Reset the browser to initial state."""
441-
self.playwright_manager.close()
442-
self.playwright_manager = setup_playwright(
429+
await self.playwright_manager.close()
430+
self.playwright_manager = await setup_playwright(
443431
headless=self.config.headless,
444432
mode=self.config.browser_mode,
445433
storage_state=self.config.storage_state,
446-
log_folder=self.config.log_folder,
434+
# log_folder=self.config.log_folder,
447435
)
448-
page = self.playwright_manager.get_page()
449-
page.goto(self.starting_url, wait_until="networkidle")
436+
page = await self.playwright_manager.get_page()
437+
await page.goto(self.starting_url, wait_until="networkidle")
450438

451-
def observe(self) -> None:
452-
page = self.playwright_manager.get_page()
453-
page_info = extract_page_info(page, self.config.fullpage, self.config.log_folder)
454-
feature_text = observe_features(
439+
async def observe(self) -> None:
440+
page = await self.playwright_manager.get_page()
441+
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
442+
feature_text = await observe_features(
455443
page_info,
456444
features=self.config.features,
457445
elements_filter=self.config.elements_filter,
@@ -465,26 +453,47 @@ def observe(self) -> None:
465453
)
466454
return observation
467455

468-
def execute_action_trajectory(self, action_trajectory: list[dict]) -> None:
456+
async def execute_action_trajectory(self, action_trajectory: list[dict]) -> None:
469457
if not action_trajectory:
470458
return True
471459

472-
self._reset_browser()
460+
await self._reset_browser()
461+
print("taking action trajectory")
473462
for action_data in action_trajectory:
474-
success = step_execution(action_data, self.playwright_manager, self.config.log_folder)
463+
print("action_data")
464+
print(action_data)
465+
466+
# Convert action_data dict to LATSNode
467+
temp_node = LATSNode(
468+
natural_language_description=action_data["natural_language_description"],
469+
action=action_data["action"],
470+
prob=action_data["prob"],
471+
element=action_data["element"],
472+
goal=self.goal,
473+
parent=None # No parent needed for temporary node
474+
)
475+
476+
success = await playwright_step_execution(
477+
temp_node, # Pass the node instead of raw action_data
478+
self.goal,
479+
self.playwright_manager,
480+
is_replay=False,
481+
log_folder=self.config.log_folder
482+
)
483+
475484
if not success:
476485
return False
477486
return True
478487

479-
def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
488+
async def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
480489
trajectory = node.get_trajectory()
481490
action_trajectory = node.get_action_trajectory()
482-
self.execute_action_trajectory(action_trajectory)
483-
observation = self.observe()
491+
await self.execute_action_trajectory(action_trajectory)
492+
observation = await self.observe()
484493
# only root node has no observation at this point
485494
if node.observation is None:
486495
node.observation = observation
487-
actions = generate_actions_with_observation(
496+
actions = await generate_actions_with_observation(
488497
trajectory,
489498
self.goal,
490499
self.images,
@@ -497,56 +506,56 @@ def generate_candidate_actions(self, node: LATSNode) -> list[dict]:
497506
action_generation_model=self.config.action_generation_model,
498507
)
499508

500-
page = self.playwright_manager.get_page()
509+
page = await self.playwright_manager.get_page()
501510
valid_actions = []
502511
for action_data in actions:
503512
if action_data["action"] == "FINISH":
504513
continue
505514

506-
is_bid_action, element_data = locate_element_from_action(page, action_data["action"])
515+
is_bid_action, element_data = await locate_element_from_action(page, action_data["action"])
507516
if is_bid_action and not element_data:
508517
continue
509518

510519
action_data['element'] = element_data
511520
valid_actions.append(action_data)
512521
return valid_actions
513522

514-
def generate_children(self, node: LATSNode) -> list[LATSNode]:
515-
logger.info(f"{GREEN}-- generating candidate actions...{RESET}")
523+
async def generate_children(self, node: LATSNode) -> list[LATSNode]:
524+
print(f"{GREEN}-- generating candidate actions...{RESET}")
516525

517526
children = []
518527

519528
action_trajectory = node.get_action_trajectory()
520-
candidate_actions = self.generate_candidate_actions(node)
521-
logger.info(f"{GREEN}-- generated {len(candidate_actions)} actions{RESET}")
529+
candidate_actions = await self.generate_candidate_actions(node)
530+
print(f"{GREEN}-- generated {len(candidate_actions)} actions{RESET}")
522531
for action_data in candidate_actions:
523-
logger.info(f"{GREEN}--- {action_data['action']}{RESET}")
524-
logger.info(f"{GREEN}--- {action_data['natural_language_description']}{RESET}")
532+
print(f"{GREEN}--- {action_data['action']}{RESET}")
533+
print(f"{GREEN}--- {action_data['natural_language_description']}{RESET}")
525534

526-
logger.info(f"")
527-
logger.info(f"{GREEN}-- executing candidate trajectories{RESET}")
535+
print(f"")
536+
print(f"{GREEN}-- executing candidate trajectories{RESET}")
528537
for i, action_data in enumerate(candidate_actions):
529538

530539
candidate_action_trajectory = action_trajectory + [action_data]
531-
logger.info(f"{GREEN}--- trajectory {i+1}:{RESET}")
540+
print(f"{GREEN}--- trajectory {i+1}:{RESET}")
532541
for action in candidate_action_trajectory:
533-
logger.info(f"{GREEN}---- {action['action']}{RESET}")
534-
logger.info(f"{GREEN}---- {action['natural_language_description']}{RESET}")
535-
executed_successfully = self.execute_action_trajectory(candidate_action_trajectory)
542+
print(f"{GREEN}---- {action['action']}{RESET}")
543+
print(f"{GREEN}---- {action['natural_language_description']}{RESET}")
544+
executed_successfully = await self.execute_action_trajectory(candidate_action_trajectory)
536545
if not executed_successfully:
537546
# not executed successfully, give up this candidate
538-
logger.info(f"{RED}--- failed to execute action trajectory{RESET}")
547+
print(f"{RED}--- failed to execute action trajectory{RESET}")
539548
continue
540549

541-
observation = self.observe()
542-
logger.info(f"{GREEN}--- generate feedback...{RESET}")
543-
feedback = generate_feedback_with_screenshot(
550+
observation = await self.observe()
551+
print(f"{GREEN}--- generate feedback...{RESET}")
552+
feedback = await generate_feedback_with_screenshot(
544553
self.goal,
545554
action_data["natural_language_description"],
546555
observation.image,
547556
model=self.config.feedback_model,
548557
)
549-
logger.info(f"feedback: is_done: {feedback.is_done}, explanation: {feedback.explanation}")
558+
print(f"feedback: is_done: {feedback.is_done}, explanation: {feedback.explanation}")
550559

551560
child = LATSNode(
552561
natural_language_description=action_data["natural_language_description"],

visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/observation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ async def observe_features(page_info, features, elements_filter, log_folder, ful
506506

507507
feature_texts = []
508508
if "axtree" in features:
509-
axtree_str = await flatten_axtree_to_str(page_info.get('axtree', ''), extra_properties=page_info['extra_properties'], filter_som_only=filter_som_only, filter_visible_only=filter_visible_only)
509+
axtree_str = flatten_axtree_to_str(page_info.get('axtree', ''), extra_properties=page_info['extra_properties'], filter_som_only=filter_som_only, filter_visible_only=filter_visible_only)
510510
feature_texts.append(ACCESSIBILITY_FEATURE_TEMPLATE.format(axtree_str=axtree_str))
511511

512512
if "interactive_elements" in features:

0 commit comments

Comments
 (0)