Skip to content

Commit 8b38805

Browse files
committed
step 4 and step 5 works
1 parent bb63603 commit 8b38805

File tree

1 file changed

+20
-20
lines changed
  • visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents

1 file changed

+20
-20
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
275275
print_entire_tree(self.root_node)
276276
return await self.rollout(node, max_depth=max_depth)
277277

278-
def send_completion_request(self, plan, depth, node, trajectory=[]):
278+
async def send_completion_request(self, plan, depth, node, trajectory=[]):
279279
print("print the trajectory")
280280
print_trajectory(node)
281281
print("print the entire tree")
@@ -284,20 +284,20 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
284284
if depth >= self.config.max_depth:
285285
return trajectory, node
286286

287-
context = self.playwright_manager.get_context()
288-
page = self.playwright_manager.get_page()
287+
context = await self.playwright_manager.get_context()
288+
page = await self.playwright_manager.get_page()
289289
# Extract page information
290290
time.sleep(3)
291-
page_info = extract_page_info(page, fullpage=True, log_folder=self.config.log_folder)
292-
updated_actions = extract_top_actions(
291+
page_info = await extract_page_info(page, fullpage=True, log_folder=self.config.log_folder)
292+
updated_actions = await extract_top_actions(
293293
trajectory, self.goal, self.images, page_info, self.action_set, openai_client,
294294
features=["axtree"], elements_filter="som", branching_factor=self.config.branching_factor,
295295
log_folder=self.config.log_folder, fullpage=True,
296296
action_generation_model=self.config.action_generation_model,
297297
action_grounding_model=self.config.action_grounding_model
298298
)
299299
next_action = updated_actions[0]
300-
retry_count = self.config.retry_count if hasattr(self.config, 'retry_count') else 3 # Default retries if not set
300+
retry_count = self.config.retry_count if hasattr(self.config, 'retry_count') else 1 # Default retries if not set
301301

302302
for attempt in range(retry_count):
303303
try:
@@ -308,13 +308,13 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
308308
if len(function_calls) == 1:
309309
for function_name, function_args in function_calls:
310310
extracted_number = parse_function_args(function_args)
311-
element = locate_element(page, extracted_number)
311+
element = await locate_element(page, extracted_number)
312312
next_action["element"] = element
313313

314314
# Execute action
315-
execute_action(next_action, self.action_set, page, context, self.goal, page_info['interactive_elements'],
315+
await execute_action(next_action, self.action_set, page, context, self.goal, page_info['interactive_elements'],
316316
self.config.log_folder)
317-
feedback = capture_post_action_feedback(page, next_action, self.goal, self.config.log_folder)
317+
feedback = await capture_post_action_feedback(page, next_action, self.goal, self.config.log_folder)
318318
trajectory.append({'action': next_action['action'], 'feedback': feedback})
319319
action_str = next_action["action"]
320320

@@ -328,7 +328,7 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
328328
messages.append({"role": "user", "content": 'action is: {}'.format(action)})
329329
messages.append({"role": "user", "content": 'action feedback is: {}'.format(feedback)})
330330

331-
goal_finished = is_goal_finished(messages, openai_client)
331+
goal_finished = await is_goal_finished(messages, openai_client)
332332

333333
new_node = LATSNode(
334334
natural_language_description=next_action["natural_language_description"],
@@ -342,22 +342,22 @@ def send_completion_request(self, plan, depth, node, trajectory=[]):
342342
if goal_finished:
343343
return trajectory, new_node
344344

345-
return self.send_completion_request(plan, depth + 1, new_node, trajectory)
345+
return await self.send_completion_request(plan, depth + 1, new_node, trajectory)
346346

347347
except Exception as e:
348348
print(f"Attempt {attempt + 1} failed with error: {e}")
349349
if attempt + 1 == retry_count:
350350
print("Max retries reached. Skipping this step and retrying the whole request.")
351351
# Retry the entire request from the same state
352-
return self.send_completion_request(plan, depth, node, trajectory)
352+
return await self.send_completion_request(plan, depth, node, trajectory)
353353

354354
# If all retries and retries of retries fail, return the current trajectory and node
355355
return trajectory, node
356356

357357

358-
def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
358+
async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
359359
# Reset browser state
360-
self._reset_browser()
360+
await self._reset_browser()
361361
path = self.get_path_to_root(node)
362362

363363
print("execute path")
@@ -367,7 +367,7 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
367367
trajectory = []
368368

369369
for n in path[1:]: # Skip root node
370-
success = playwright_step_execution(
370+
success = await playwright_step_execution(
371371
n,
372372
self.goal,
373373
self.playwright_manager,
@@ -377,7 +377,7 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
377377
if not success:
378378
return 0, n
379379
if not n.feedback:
380-
n.feedback = generate_feedback(
380+
n.feedback = await generate_feedback(
381381
self.goal,
382382
n.natural_language_description,
383383
self.playwright_manager,
@@ -389,14 +389,14 @@ def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
389389
## call the prompt agent
390390
print("current depth: ", len(path) - 1)
391391
print("max depth: ", self.config.max_depth)
392-
trajectory, node = self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory)
392+
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory)
393393
print("print the trajectory")
394394
print_trajectory(node)
395395
print("print the entire tree")
396396
print_entire_tree(self.root_node)
397397

398-
page = self.playwright_manager.get_page()
399-
page_info = extract_page_info(page, self.config.fullpage, self.config.log_folder)
398+
page = await self.playwright_manager.get_page()
399+
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
400400

401401
messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]]
402402
goal_finished, confidence_score = goal_finished_evaluator(
@@ -467,7 +467,7 @@ async def execute_action_trajectory(self, action_trajectory: list[dict]) -> None
467467
temp_node = LATSNode(
468468
natural_language_description=action_data["natural_language_description"],
469469
action=action_data["action"],
470-
prob=action_data["prob"],
470+
prob=0,
471471
element=action_data["element"],
472472
goal=self.goal,
473473
parent=None # No parent needed for temporary node

0 commit comments

Comments
 (0)