Skip to content

Commit 283e176

Browse files
committed
add mcts agent node selection
1 parent 0534af4 commit 283e176

File tree

1 file changed

+107
-1
lines changed
  • visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents

1 file changed

+107
-1
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Any, Optional, Tuple, List
22
from datetime import datetime
33
from dotenv import load_dotenv
4-
load_dotenv()
4+
import json
55

66
from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
77
from .lats_node import LATSNode
88
from .base_agent import BaseAgent
9+
import openai
910

1011
class MCTSAgent(BaseAgent):
1112
async def run(self, websocket=None) -> list[LATSNode]:
@@ -24,6 +25,102 @@ async def run(self, websocket=None) -> list[LATSNode]:
2425

2526
# Performs Monte Carlo Tree Search starting from the root node with WebSocket updates.
2627
# Uses GPT-4 for node selection and reflection-based backpropagation.
28+
# TODO: if we select non-leaf node, do we expand the node again?
29+
# TODO: modify node selection logic to choose between the node and the children of the node,
30+
async def node_selection(self, node: LATSNode, websocket=None) -> Optional[LATSNode]:
31+
# start from the root node
32+
if node.is_terminal:
33+
return None
34+
35+
current_node = node
36+
path = [current_node]
37+
selection_depth = 0
38+
39+
while current_node.children and not current_node.is_terminal:
40+
print(f"\nSelection Step {selection_depth + 1}:")
41+
print(f"Current node action: {current_node.action}")
42+
print(f"Number of children: {len(current_node.children)}")
43+
44+
# Get trajectory for GPT-4 to evaluate
45+
trajectory = []
46+
for node in path[1:]: # Skip root node
47+
trajectory.append({
48+
"natural_language_description": node.natural_language_description,
49+
"action": node.action,
50+
"feedback": node.feedback if hasattr(node, 'feedback') else None
51+
})
52+
53+
# Create prompt for GPT-4
54+
prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next.
55+
Consider the overall progress, efficiency, and likelihood of success.
56+
57+
Goal: {self.goal}
58+
59+
Current Trajectory:
60+
{json.dumps(trajectory, indent=2)}
61+
62+
Available Children:
63+
{json.dumps([{
64+
'action': child.action,
65+
'description': child.natural_language_description,
66+
'visits': child.visits,
67+
'value': child.value if hasattr(child, 'value') else 0
68+
} for child in current_node.children], indent=2)}
69+
70+
Return a JSON response with:
71+
{{
72+
"selected_child_index": int, # Index of the selected child
73+
"explanation": str # Brief explanation of the selection
74+
}}"""
75+
76+
try:
77+
response = openai.ChatCompletion.create(
78+
model=self.config.evaluation_model,
79+
messages=[
80+
{"role": "system", "content": "You are an expert at selecting promising paths in a search tree."},
81+
{"role": "user", "content": prompt}
82+
],
83+
response_format={"type": "json_object"}
84+
)
85+
86+
selection = json.loads(response.choices[0].message.content)
87+
selected_index = selection["selected_child_index"]
88+
89+
if 0 <= selected_index < len(current_node.children):
90+
current_node = current_node.children[selected_index]
91+
path.append(current_node)
92+
print(f"Selected child {selected_index + 1}: {current_node.action}")
93+
print(f"Selection explanation: {selection['explanation']}")
94+
95+
if websocket:
96+
await websocket.send_json({
97+
"type": "node_selected",
98+
"node_id": id(current_node),
99+
"parent_id": id(current_node.parent),
100+
"action": current_node.action,
101+
"description": current_node.natural_language_description,
102+
"explanation": selection["explanation"],
103+
"depth": selection_depth + 1,
104+
"timestamp": datetime.utcnow().isoformat()
105+
})
106+
else:
107+
print(f"Invalid child index {selected_index}, breaking selection")
108+
break
109+
110+
except Exception as e:
111+
print(f"Error in node selection: {str(e)}")
112+
if websocket:
113+
await websocket.send_json({
114+
"type": "selection_error",
115+
"error": str(e),
116+
"timestamp": datetime.utcnow().isoformat()
117+
})
118+
break
119+
120+
selection_depth += 1
121+
122+
return current_node
123+
27124
async def mcts_search(self, websocket=None):
28125
for i in range(self.config.iterations):
29126
await self.websocket_iteration_start(i, websocket=websocket)
@@ -32,6 +129,15 @@ async def mcts_search(self, websocket=None):
32129

33130
# Step 1: Node Selection
34131
# Selection: Use GPT-4 to select a promising path
132+
print(f"{GREEN}Step 1: node selection{RESET}")
133+
await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
134+
node = await self.node_selection(self.root_node)
135+
await self.websocket_node_selection(node, websocket=websocket)
136+
137+
if node is None:
138+
print("All paths lead to terminal nodes with reward 0. Ending search.")
139+
break
140+
35141

36142

37143
# Step 2: Node Expansion

0 commit comments

Comments
 (0)