Skip to content

Commit f2f35fe

Browse files
authored
Merge pull request #88 from PathOnAI/add-mcts-frontend
Add mcts frontend
2 parents f5ea5eb + 82f3582 commit f2f35fe

File tree

5 files changed

+35
-13
lines changed

5 files changed

+35
-13
lines changed

visual-tree-search-app/components/LATSVisual.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ const LATSVisual: React.FC<SimpleSearchVisualProps> = ({ messages }) => {
9595
}
9696

9797
// Handle tree structure updates
98-
if ((data.type === 'tree_update_node_expansion' || data.type === 'tree_update_node_children_evaluation' || data.typ === 'tree_update_node_backpropagation')
98+
if ((data.type === 'tree_update_node_expansion' || data.type === 'tree_update_node_children_evaluation' || data.type === 'tree_update_node_backpropagation')
9999
&& Array.isArray(data.tree)) {
100100
// Preserve simulation flags when updating from tree
101101
if (updatedTreeNodes.some(node => node.isSimulated)) {

visual-tree-search-app/components/MCTSVisual.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ const MCTSVisual: React.FC<SimpleSearchVisualProps> = ({ messages }) => {
9595
}
9696

9797
// Handle tree structure updates
98-
if ((data.type === 'tree_update_node_expansion' || data.type === 'tree_update_node_children_evaluation' || data.typ === 'tree_update_node_backpropagation')
98+
if ((data.type === 'tree_update_node_expansion' || data.type === 'tree_update_node_children_evaluation' || data.type === 'tree_update_node_backpropagation')
9999
&& Array.isArray(data.tree)) {
100100
// Preserve simulation flags when updating from tree
101101
if (updatedTreeNodes.some(node => node.isSimulated)) {
@@ -483,14 +483,14 @@ const MCTSVisual: React.FC<SimpleSearchVisualProps> = ({ messages }) => {
483483
<span className="w-3 h-3 rounded-full inline-block mr-1 bg-blue-500 dark:bg-blue-600"></span>
484484
<span className="text-gray-700 dark:text-gray-300">Selected</span>
485485
</div>
486-
<div className="flex items-center">
486+
{/* <div className="flex items-center">
487487
<span className="w-3 h-3 rounded-full inline-block mr-1 bg-green-500 dark:bg-green-600"></span>
488488
<span className="text-gray-700 dark:text-gray-300">Sim Start</span>
489489
</div>
490490
<div className="flex items-center">
491491
<span className="w-3 h-3 rounded-full inline-block mr-1 bg-orange-500 dark:bg-orange-600"></span>
492492
<span className="text-gray-700 dark:text-gray-300">Simulated</span>
493-
</div>
493+
</div> */}
494494
</div>
495495
</div>
496496
<div

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,8 @@ async def node_evaluation(self, node: LATSNode) -> None:
487487
def backpropagate(self, node: LATSNode, value: float) -> None:
488488
while node:
489489
node.visits += 1
490-
node.value = (node.value * (node.visits - 1) + value) / node.visits
490+
# Calculate running average: newAvg = oldAvg + (value - oldAvg) / newCount
491+
node.value += (value - node.value) / node.visits
491492
node = node.parent
492493

493494
# shared
@@ -558,7 +559,7 @@ async def rollout(self, node: LATSNode, websocket=None)-> tuple[float, LATSNode]
558559
score = confidence_score if goal_finished else 0
559560
await self.remove_simulated_trajectory(starting_node=node, terminal_node=terminal_node, websocket=websocket)
560561

561-
return score, node
562+
return score, terminal_node
562563

563564

564565
# TODO: decide whether to keep the tree update

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,10 @@ async def lats_search(self, websocket=None):
8585
# Step 5: Backpropagation
8686
print(f"{GREEN}Step 5: backpropagation{RESET}")
8787
await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
88-
self.backpropagate(terminal_node, reward)
88+
self.backpropagate(selected_node, reward)
8989
tree_data = self._get_tree_data()
90+
print_entire_tree(self.root_node)
91+
print(tree_data)
9092
if websocket:
9193
await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
9294
else:

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,12 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
264264
print(f"{GREEN}Step 1: Node Selection{RESET}")
265265
await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
266266
selected_node = await self.node_selection(self.root_node, websocket)
267-
tree_data = self._get_tree_data()
268-
if websocket:
269-
await self.websocket_tree_update(type="tree_update_node_selection", websocket=websocket, tree_data=tree_data)
270-
else:
271-
print_entire_tree(self.root_node)
267+
# await self.websocket_node_selection(selected_node, websocket=websocket)
268+
# tree_data = self._get_tree_data()
269+
# if websocket:
270+
# await self.websocket_tree_update(type="tree_update_node_selection", websocket=websocket, tree_data=tree_data)
271+
# else:
272+
# print_entire_tree(self.root_node)
272273

273274
if selected_node is None:
274275
logger.warning("All paths lead to terminal nodes. Ending search.")
@@ -338,10 +339,28 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
338339
for node in path:
339340
old_value = node.value
340341
node.visits += 1
341-
node.value = (node.value * (node.visits - 1) + score) / node.visits
342+
node.value += (score - node.value) / node.visits
343+
# consiste with lats backpropagation
344+
#node.value = (node.value * (node.visits - 1) + score) / node.visits
342345
print(f"Node {node.action}:")
343346
print(f" Visits: {node.visits}")
344347
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
348+
# add websocket information, just use websocket here
349+
# if websocket:
350+
# await websocket.send_json({
351+
# "type": "backpropagation",
352+
# "node_id": id(node),
353+
# "node_parent_id": id(node.parent),
354+
# "node_action": node.action,
355+
# "node_value": node.value,
356+
# "node_visits": node.visits,
357+
# "node_old_value": old_value,
358+
# "node_description": node.natural_language_description,
359+
# })
360+
361+
tree_data = self._get_tree_data()
362+
print_entire_tree(self.root_node)
363+
print(tree_data)
345364
if websocket:
346365
await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
347366
else:

0 commit comments

Comments
 (0)