Skip to content

Commit 917889d

Browse files
committed
Remove the reward attribute of the node and add proper close session logic to all agents
1 parent a6a814f commit 917889d

File tree

6 files changed: +23 -12 lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def _get_tree_data(self):
102102
"value": node.value,
103103
"visits": node.visits,
104104
"feedback": node.feedback,
105-
"reward": node.reward
105+
# "reward": node.reward
106106
}
107107
tree_data.append(node_data)
108108

@@ -129,7 +129,7 @@ async def remove_simulated_trajectory(self, starting_node, terminal_node: LATSNo
129129
"description": node.natural_language_description,
130130
"visits": node.visits,
131131
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
132-
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
132+
# "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
133133
"is_terminal": node.is_terminal,
134134
"feedback": node.feedback if hasattr(node, 'feedback') else None,
135135
"is_root": not hasattr(node, 'parent') or node.parent is None,
@@ -159,7 +159,7 @@ def _get_trajectory_data(self, terminal_node: LATSNode):
159159
"description": node.natural_language_description,
160160
"visits": node.visits,
161161
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
162-
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
162+
# "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
163163
"is_terminal": node.is_terminal,
164164
"feedback": node.feedback if hasattr(node, 'feedback') else None,
165165
"is_root": not hasattr(node, 'parent') or node.parent is None,
@@ -432,7 +432,7 @@ async def node_children_evaluation(self, node: LATSNode) -> None:
432432

433433
for child, score in zip(node.children, scores):
434434
child.value = score
435-
child.reward = score
435+
# child.reward = score
436436

437437
async def node_evaluation(self, node: LATSNode) -> None:
438438
"""Evaluate the current node and assign its score."""
@@ -469,7 +469,7 @@ async def node_evaluation(self, node: LATSNode) -> None:
469469

470470
# Assign the score to the node
471471
node.value = score
472-
node.reward = score
472+
# node.reward = score
473473

474474

475475
except Exception as e:

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ async def lats_search(self, websocket=None):
3737
await self.websocket_node_selection(node, websocket=websocket)
3838

3939
if node is None:
40-
print("All paths lead to terminal nodes with reward 0. Ending search.")
40+
print("All paths lead to terminal nodes with value 0. Ending search.")
4141
break
4242

4343
# Step 2: Node Expansion
@@ -76,8 +76,10 @@ async def lats_search(self, websocket=None):
7676
terminal_nodes.append(terminal_node)
7777
await self.websocket_simulation_result(reward, terminal_node, websocket=websocket)
7878

79-
if reward == 1:
79+
# simulation score threshold
80+
if reward >= self.config.simulation_score:
8081
await self.websocket_search_complete("success", reward, terminal_node.get_trajectory(), websocket=websocket)
82+
await self.playwright_manager.close()
8183
return terminal_node
8284

8385
# Step 5: Backpropagation
@@ -95,8 +97,8 @@ async def lats_search(self, websocket=None):
9597
all_nodes_list = collect_all_nodes(self.root_node)
9698
all_nodes_list.extend(terminal_nodes)
9799

98-
## temp change: if reward is the same, choose the deeper node
99-
best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth))
100+
## temp change: if value is the same, choose the deeper node
101+
best_child = max(all_nodes_list, key=lambda x: (x.value, x.depth))
100102

101103
if best_child.value >= 0.75:
102104
print("Successful trajectory found")

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def __init__(
8888
self.value = 0.0
8989
self.depth = 0 if parent is None else parent.depth + 1
9090
self.is_terminal = False
91-
self.reward = 0.0
91+
# self.reward = 0.0
9292
self.exhausted = False # If all children are terminal
9393
self.em = 0.0 # Exact match, evaluation metric
9494
self.observation: Optional[Observation] = None
@@ -177,7 +177,7 @@ def to_dict(self) -> dict:
177177
'value': self.value,
178178
'depth': self.depth,
179179
'is_terminal': self.is_terminal,
180-
'reward': self.reward,
180+
# 'reward': self.reward,
181181
'em': self.em,
182182
}
183183

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -302,6 +302,7 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
302302
# Convert path to serializable trajectory
303303
# trajectory = [node.action for node in path if node.action is not None]
304304
await self.websocket_search_complete("success", score, selected_node.get_trajectory(), websocket=websocket)
305+
await self.playwright_manager.close()
305306
return selected_node
306307

307308
print(f"path: {path}")
@@ -328,4 +329,5 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
328329
# Convert node to serializable trajectory
329330
# trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
330331
await self.websocket_search_complete("partial_success", best_node.value, best_node.get_trajectory(), websocket=websocket)
332+
await self.playwright_manager.close()
331333
return best_node

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,7 @@ async def bfs(self, websocket=None):
110110

111111
# Send completion update if websocket is provided
112112
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
113+
await self.playwright_manager.close()
113114

114115
return current_node
115116

@@ -120,6 +121,7 @@ async def bfs(self, websocket=None):
120121

121122
# Send completion update if websocket is provided
122123
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
124+
await self.playwright_manager.close()
123125

124126
return best_node
125127

@@ -128,6 +130,7 @@ async def bfs(self, websocket=None):
128130

129131
# Send failure update if websocket is provided
130132
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
133+
await self.playwright_manager.close()
131134

132135
return None
133136

@@ -209,7 +212,8 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
209212
print(f"Found satisfactory solution with score {score}")
210213

211214
# Send completion update if websocket is provided
212-
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
215+
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
216+
await self.playwright_manager.close()
213217
return current_node
214218

215219
# Add non-terminal children to stack in reverse order
@@ -234,6 +238,7 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
234238

235239
# Send completion update if websocket is provided
236240
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
241+
await self.playwright_manager.close()
237242

238243
return best_node
239244

@@ -242,6 +247,7 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
242247

243248
# Send failure update if websocket is provided
244249
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
250+
await self.playwright_manager.close()
245251

246252
return None
247253

visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ class AgentConfig:
2525
num_simulations: int = 1
2626
account_reset: bool = True
2727

28+
simulation_score: float = 0.75
2829
reflection_score: float = 0.75
2930

3031
# Features

0 commit comments

Comments
 (0)