Skip to content

Commit 35f85b2

Browse files
committed
recover trajectory from LASTNode to serializable action
1 parent c7afb23 commit 35f85b2

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
289289
print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
290290
await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
291291
if score >= self.config.reflection_score:
292-
await self.websocket_search_complete("success", score, path, websocket=websocket)
292+
# Convert path to serializable trajectory
293+
trajectory = [node.action for node in path if node.action is not None]
294+
await self.websocket_search_complete("success", score, trajectory, websocket=websocket)
293295
return node
294296

295297
print(f"path: {path}")

visual-tree-search-backend/test/test-tree-search-ws-lats.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ async def connect_and_test_search(
7777
starting_url: str,
7878
goal: str,
7979
search_algorithm: str = "bfs",
80-
max_depth: int = 3
80+
max_depth: int = 3,
81+
iterations: int = 5
8182
):
8283
"""
8384
Connect to the WebSocket endpoint and test the tree search functionality.
@@ -88,6 +89,7 @@ async def connect_and_test_search(
8889
goal: Goal to achieve
8990
search_algorithm: Search algorithm to use (bfs or dfs)
9091
max_depth: Maximum depth for the search tree
92+
iterations: Number of iterations for LATS algorithm
9193
"""
9294
logger.info(f"Connecting to WebSocket at {ws_url}")
9395

@@ -107,7 +109,8 @@ async def connect_and_test_search(
107109
"starting_url": starting_url,
108110
"goal": goal,
109111
"search_algorithm": search_algorithm,
110-
"max_depth": max_depth
112+
"max_depth": max_depth,
113+
"iterations": iterations
111114
}
112115

113116
logger.info(f"Sending search request: {request}")
@@ -156,6 +159,9 @@ def parse_arguments():
156159
parser.add_argument("--max-depth", type=int, default=3,
157160
help="Maximum depth for the search tree (default: 3)")
158161

162+
parser.add_argument("--iterations", type=int, default=5,
163+
help="Number of iterations for LATS algorithm (default: 5)")
164+
159165
# Add the new argument for log file
160166
parser.add_argument("--log-file", type=str,
161167
help="File to save the colored output to")
@@ -196,14 +202,16 @@ def flush(self):
196202
logger.info(f"Goal: {args.goal}")
197203
logger.info(f"Algorithm: {args.algorithm}")
198204
logger.info(f"Max depth: {args.max_depth}")
205+
logger.info(f"Iterations: {args.iterations}")
199206

200207
try:
201208
await connect_and_test_search(
202209
ws_url=args.ws_url,
203210
starting_url=args.starting_url,
204211
goal=args.goal,
205212
search_algorithm=args.algorithm,
206-
max_depth=args.max_depth
213+
max_depth=args.max_depth,
214+
iterations=args.iterations
207215
)
208216
finally:
209217
# Clean up if logging to file

visual-tree-search-backend/test/test-tree-search-ws-mcts.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ async def connect_and_test_search(
7777
starting_url: str,
7878
goal: str,
7979
search_algorithm: str = "bfs",
80-
max_depth: int = 3
80+
max_depth: int = 3,
81+
iterations: int = 5
8182
):
8283
"""
8384
Connect to the WebSocket endpoint and test the tree search functionality.
@@ -88,6 +89,7 @@ async def connect_and_test_search(
8889
goal: Goal to achieve
8990
search_algorithm: Search algorithm to use (bfs or dfs)
9091
max_depth: Maximum depth for the search tree
92+
iterations: Number of iterations for MCTS algorithm
9193
"""
9294
logger.info(f"Connecting to WebSocket at {ws_url}")
9395

@@ -99,16 +101,23 @@ async def connect_and_test_search(
99101
data = json.loads(response)
100102
if data.get("type") == "connection_established":
101103
logger.info(f"Connection established with ID: {data.get('connection_id')}")
102-
104+
if search_algorithm in ["bfs", "dfs"]:
105+
agent_type = "SimpleSearchAgent"
106+
elif search_algorithm in ["lats"]:
107+
agent_type = "LATSAgent"
108+
elif search_algorithm in ["mcts"]:
109+
agent_type = "MCTSAgent"
110+
else:
111+
raise ValueError(f"Invalid search algorithm: {search_algorithm}")
103112
# Send search request
104113
request = {
105114
"type": "start_search",
106-
"agent_type": "MCTSAgent",
115+
"agent_type": agent_type,
107116
"starting_url": starting_url,
108117
"goal": goal,
109118
"search_algorithm": search_algorithm,
110119
"max_depth": max_depth,
111-
"iterations": 10
120+
"iterations": iterations
112121
}
113122

114123
logger.info(f"Sending search request: {request}")
@@ -151,12 +160,15 @@ def parse_arguments():
151160
parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
152161
help=f"Goal to achieve (default: {DEFAULT_GOAL})")
153162

154-
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="lats",
155-
help="Search algorithm to use (default: lats)")
163+
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="mcts",
164+
help="Search algorithm to use (default: mcts)")
156165

157166
parser.add_argument("--max-depth", type=int, default=3,
158167
help="Maximum depth for the search tree (default: 3)")
159168

169+
parser.add_argument("--iterations", type=int, default=5,
170+
help="Number of iterations for LATS algorithm (default: 5)")
171+
160172
# Add the new argument for log file
161173
parser.add_argument("--log-file", type=str,
162174
help="File to save the colored output to")
@@ -197,6 +209,7 @@ def flush(self):
197209
logger.info(f"Goal: {args.goal}")
198210
logger.info(f"Algorithm: {args.algorithm}")
199211
logger.info(f"Max depth: {args.max_depth}")
212+
logger.info(f"Iterations: {args.iterations}")
200213

201214
try:
202215
await connect_and_test_search(
@@ -205,6 +218,8 @@ def flush(self):
205218
goal=args.goal,
206219
search_algorithm=args.algorithm,
207220
max_depth=args.max_depth
221+
,
222+
iterations=args.iterations
208223
)
209224
finally:
210225
# Clean up if logging to file

0 commit comments

Comments
 (0)