@@ -77,7 +77,8 @@ async def connect_and_test_search(
7777 starting_url : str ,
7878 goal : str ,
7979 search_algorithm : str = "bfs" ,
80- max_depth : int = 3
80+ max_depth : int = 3 ,
81+ iterations : int = 5
8182):
8283 """
8384 Connect to the WebSocket endpoint and test the tree search functionality.
@@ -88,6 +89,7 @@ async def connect_and_test_search(
8889 goal: Goal to achieve
8990 search_algorithm: Search algorithm to use (bfs or dfs)
9091 max_depth: Maximum depth for the search tree
92+ iterations: Number of iterations for MCTS algorithm
9193 """
9294 logger .info (f"Connecting to WebSocket at { ws_url } " )
9395
@@ -99,16 +101,23 @@ async def connect_and_test_search(
99101 data = json .loads (response )
100102 if data .get ("type" ) == "connection_established" :
101103 logger .info (f"Connection established with ID: { data .get ('connection_id' )} " )
102-
104+ if search_algorithm in ["bfs" , "dfs" ]:
105+ agent_type = "SimpleSearchAgent"
106+ elif search_algorithm in ["lats" ]:
107+ agent_type = "LATSAgent"
108+ elif search_algorithm in ["mcts" ]:
109+ agent_type = "MCTSAgent"
110+ else :
111+ raise ValueError (f"Invalid search algorithm: { search_algorithm } " )
103112 # Send search request
104113 request = {
105114 "type" : "start_search" ,
106- "agent_type" : "MCTSAgent" ,
115+ "agent_type" : agent_type ,
107116 "starting_url" : starting_url ,
108117 "goal" : goal ,
109118 "search_algorithm" : search_algorithm ,
110119 "max_depth" : max_depth ,
111- "iterations" : 10
120+ "iterations" : iterations
112121 }
113122
114123 logger .info (f"Sending search request: { request } " )
@@ -151,12 +160,15 @@ def parse_arguments():
151160 parser .add_argument ("--goal" , type = str , default = DEFAULT_GOAL ,
152161 help = f"Goal to achieve (default: { DEFAULT_GOAL } )" )
153162
154- parser .add_argument ("--algorithm" , type = str , choices = ["bfs" , "dfs" , "lats" , "mcts" ], default = "lats " ,
155- help = "Search algorithm to use (default: lats )" )
163+ parser .add_argument ("--algorithm" , type = str , choices = ["bfs" , "dfs" , "lats" , "mcts" ], default = "mcts " ,
164+ help = "Search algorithm to use (default: mcts )" )
156165
157166 parser .add_argument ("--max-depth" , type = int , default = 3 ,
158167 help = "Maximum depth for the search tree (default: 3)" )
159168
169+ parser .add_argument ("--iterations" , type = int , default = 5 ,
170+ help = "Number of iterations for LATS algorithm (default: 5)" )
171+
160172 # Add the new argument for log file
161173 parser .add_argument ("--log-file" , type = str ,
162174 help = "File to save the colored output to" )
@@ -197,6 +209,7 @@ def flush(self):
197209 logger .info (f"Goal: { args .goal } " )
198210 logger .info (f"Algorithm: { args .algorithm } " )
199211 logger .info (f"Max depth: { args .max_depth } " )
212+ logger .info (f"Iterations: { args .iterations } " )
200213
201214 try :
202215 await connect_and_test_search (
@@ -205,6 +218,8 @@ def flush(self):
205218 goal = args .goal ,
206219 search_algorithm = args .algorithm ,
207220 max_depth = args .max_depth
221+ ,
222+ iterations = args .iterations
208223 )
209224 finally :
210225 # Clean up if logging to file
0 commit comments