11from typing import Any , Optional , Tuple , List
22from datetime import datetime
33from dotenv import load_dotenv
4- load_dotenv ()
4+ import json
55
66from .tree_vis import RED , better_print , print_trajectory , collect_all_nodes , GREEN , RESET , print_entire_tree
77from .lats_node import LATSNode
88from .base_agent import BaseAgent
9+ import openai
910
1011class MCTSAgent (BaseAgent ):
1112 async def run (self , websocket = None ) -> list [LATSNode ]:
@@ -24,6 +25,102 @@ async def run(self, websocket=None) -> list[LATSNode]:
2425
2526 # Performs Monte Carlo Tree Search starting from the root node with WebSocket updates.
2627 # Uses GPT-4 for node selection and reflection-based backpropagation.
28+ # TODO: if we select non-leaf node, do we expand the node again?
29+ # TODO: modify node selection logic to choose between the node and the children of the node,
30+ async def node_selection (self , node : LATSNode , websocket = None ) -> Optional [LATSNode ]:
31+ # start from the root node
32+ if node .is_terminal :
33+ return None
34+
35+ current_node = node
36+ path = [current_node ]
37+ selection_depth = 0
38+
39+ while current_node .children and not current_node .is_terminal :
40+ print (f"\n Selection Step { selection_depth + 1 } :" )
41+ print (f"Current node action: { current_node .action } " )
42+ print (f"Number of children: { len (current_node .children )} " )
43+
44+ # Get trajectory for GPT-4 to evaluate
45+ trajectory = []
46+ for node in path [1 :]: # Skip root node
47+ trajectory .append ({
48+ "natural_language_description" : node .natural_language_description ,
49+ "action" : node .action ,
50+ "feedback" : node .feedback if hasattr (node , 'feedback' ) else None
51+ })
52+
53+ # Create prompt for GPT-4
54+ prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next.
55+ Consider the overall progress, efficiency, and likelihood of success.
56+
57+ Goal: { self .goal }
58+
59+ Current Trajectory:
60+ { json .dumps (trajectory , indent = 2 )}
61+
62+ Available Children:
63+ { json .dumps ([{
64+ 'action' : child .action ,
65+ 'description' : child .natural_language_description ,
66+ 'visits' : child .visits ,
67+ 'value' : child .value if hasattr (child , 'value' ) else 0
68+ } for child in current_node .children ], indent = 2 )}
69+
70+ Return a JSON response with:
71+ {{
72+ "selected_child_index": int, # Index of the selected child
73+ "explanation": str # Brief explanation of the selection
74+ }}"""
75+
76+ try :
77+ response = openai .ChatCompletion .create (
78+ model = self .config .evaluation_model ,
79+ messages = [
80+ {"role" : "system" , "content" : "You are an expert at selecting promising paths in a search tree." },
81+ {"role" : "user" , "content" : prompt }
82+ ],
83+ response_format = {"type" : "json_object" }
84+ )
85+
86+ selection = json .loads (response .choices [0 ].message .content )
87+ selected_index = selection ["selected_child_index" ]
88+
89+ if 0 <= selected_index < len (current_node .children ):
90+ current_node = current_node .children [selected_index ]
91+ path .append (current_node )
92+ print (f"Selected child { selected_index + 1 } : { current_node .action } " )
93+ print (f"Selection explanation: { selection ['explanation' ]} " )
94+
95+ if websocket :
96+ await websocket .send_json ({
97+ "type" : "node_selected" ,
98+ "node_id" : id (current_node ),
99+ "parent_id" : id (current_node .parent ),
100+ "action" : current_node .action ,
101+ "description" : current_node .natural_language_description ,
102+ "explanation" : selection ["explanation" ],
103+ "depth" : selection_depth + 1 ,
104+ "timestamp" : datetime .utcnow ().isoformat ()
105+ })
106+ else :
107+ print (f"Invalid child index { selected_index } , breaking selection" )
108+ break
109+
110+ except Exception as e :
111+ print (f"Error in node selection: { str (e )} " )
112+ if websocket :
113+ await websocket .send_json ({
114+ "type" : "selection_error" ,
115+ "error" : str (e ),
116+ "timestamp" : datetime .utcnow ().isoformat ()
117+ })
118+ break
119+
120+ selection_depth += 1
121+
122+ return current_node
123+
27124 async def mcts_search (self , websocket = None ):
28125 for i in range (self .config .iterations ):
29126 await self .websocket_iteration_start (i , websocket = websocket )
@@ -32,6 +129,15 @@ async def mcts_search(self, websocket=None):
32129
33130 # Step 1: Node Selection
34131 # Selection: Use GPT-4 to select a promising path
132+ print (f"{ GREEN } Step 1: node selection{ RESET } " )
133+ await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
134+ node = await self .node_selection (self .root_node )
135+ await self .websocket_node_selection (node , websocket = websocket )
136+
137+ if node is None :
138+ print ("All paths lead to terminal nodes with reward 0. Ending search." )
139+ break
140+
35141
36142
37143 # Step 2: Node Expansion
0 commit comments