Skip to content

Commit 0eb3158

Browse files
authored
Merge pull request #43 from PathOnAI/add-agent-template
configure mcts and lats agents
2 parents fa2c8b9 + ac84097 commit 0eb3158

File tree

8 files changed

+207
-82
lines changed

8 files changed

+207
-82
lines changed

visual-tree-search-backend/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,29 @@ to test the message passing from the backend to the frontend
5757
```
5858
curl -X POST http://localhost:3000/api/terminate-session/647f4021-2402-4733-84a3-255f0d20c151
5959
{"status":"success","message":"Session 647f4021-2402-4733-84a3-255f0d20c151 termination requested"}
60+
```
61+
62+
## 6. Add more search agents
63+
```
64+
python run_demo_treesearch_async.py \
65+
--browser-mode chromium \
66+
--storage-state shopping.json \
67+
--starting-url "http://128.105.145.205:7770/" \
68+
--agent-type "LATSAgent" \
69+
--action_generation_model "gpt-4o-mini" \
70+
--goal "search running shoes, click on the first result" \
71+
--iterations 3 \
72+
--max_depth 3
73+
```
74+
75+
```
76+
python run_demo_treesearch_async.py \
77+
--browser-mode chromium \
78+
--storage-state shopping.json \
79+
--starting-url "http://128.105.145.205:7770/" \
80+
--agent-type "MCTSAgent" \
81+
--action_generation_model "gpt-4o-mini" \
82+
--goal "search running shoes, click on the first result" \
83+
--iterations 3 \
84+
--max_depth 3
6085
```
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import time
3+
from typing import Any, Dict, List, Optional
4+
from collections import deque
5+
from datetime import datetime
6+
import os
7+
import json
8+
import subprocess
9+
10+
from openai import OpenAI
11+
from dotenv import load_dotenv
12+
load_dotenv()
13+
import aiohttp
14+
15+
from ...core_async.config import AgentConfig
16+
17+
from ...webagent_utils_async.action.highlevel import HighLevelActionSet
18+
from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright
19+
from ...webagent_utils_async.utils.utils import parse_function_args, locate_element
20+
from ...evaluation_async.evaluators import goal_finished_evaluator
21+
from ...replay_async import generate_feedback, playwright_step_execution
22+
from ...webagent_utils_async.action.prompt_functions import extract_top_actions
23+
from ...webagent_utils_async.browser_env.observation import extract_page_info
24+
from .lats_node import LATSNode
25+
from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
26+
from .trajectory_score import create_llm_prompt, score_trajectory_with_openai
27+
from ...webagent_utils_async.utils.utils import urls_to_images
28+
29+
logger = logging.getLogger(__name__)
30+
openai_client = OpenAI()
31+
32+
class LATSAgent:
    """Template for a LATS search agent.

    The constructor records the task inputs, builds the action set and the
    root ``LATSNode``, and reads the account reset URL from the environment.
    ``run`` is not implemented yet.
    """

    def __init__(
        self,
        starting_url: str,
        messages: list[dict[str, Any]],
        goal: str,
        images: list,
        playwright_manager: AsyncPlaywrightManager,
        config: AgentConfig,
    ):
        """Store the task inputs and prepare the root of the search tree.

        Args:
            starting_url: URL stored as ``self.starting_url``.
            messages: Conversation history. NOTE: this list is extended in
                place with a user message stating the goal, so the caller's
                list is modified as well.
            goal: Task description; also passed to the root node.
            images: Image references, kept as ``self.image_urls`` and handed
                to ``urls_to_images`` (presumably a URL-to-image conversion;
                confirm against that helper).
            playwright_manager: Browser manager stored on the instance.
            config: Agent configuration stored on the instance.

        Raises:
            KeyError: If the ``ACCOUNT_RESET_URL`` environment variable is
                not set.
        """
        # Task definition.
        self.starting_url = starting_url
        self.goal = goal

        # Keep the raw references next to whatever urls_to_images returns.
        self.image_urls = images
        self.images = urls_to_images(images)

        # Same list object as the caller's; the goal message is appended to it.
        goal_message = {"role": "user", "content": f"The goal is: {goal}"}
        self.messages = messages
        messages.append(goal_message)

        # Collaborators supplied by the factory.
        self.playwright_manager = playwright_manager
        self.config = config

        # Despite its name, agent_type holds the action subsets handed to
        # HighLevelActionSet.
        action_subsets = ["bid", "nav", "file", "select_option"]
        self.agent_type = action_subsets
        self.action_set = HighLevelActionSet(
            subsets=action_subsets,
            strict=False,
            multiaction=True,
            demo_mode="default",
        )

        # Root node: only the goal is set, every other field is None.
        self.root_node = LATSNode(
            natural_language_description=None,
            action=None,
            prob=None,
            element=None,
            goal=goal,
            parent=None,
        )

        # Read last so the earlier side effects happen before a KeyError.
        self.reset_url = os.environ["ACCOUNT_RESET_URL"]

    async def run(self, websocket=None) -> List[Dict[str, Any]]:
        """Placeholder for the LATS search loop.

        Args:
            websocket: Accepted but currently unused.

        Returns:
            Currently ``None``: the search is not implemented yet, despite
            the annotated return type.
        """
        return None
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import time
3+
from typing import Any, Dict, List, Optional
4+
from collections import deque
5+
from datetime import datetime
6+
import os
7+
import json
8+
import subprocess
9+
10+
from openai import OpenAI
11+
from dotenv import load_dotenv
12+
load_dotenv()
13+
import aiohttp
14+
15+
from ...core_async.config import AgentConfig
16+
17+
from ...webagent_utils_async.action.highlevel import HighLevelActionSet
18+
from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright
19+
from ...webagent_utils_async.utils.utils import parse_function_args, locate_element
20+
from ...evaluation_async.evaluators import goal_finished_evaluator
21+
from ...replay_async import generate_feedback, playwright_step_execution
22+
from ...webagent_utils_async.action.prompt_functions import extract_top_actions
23+
from ...webagent_utils_async.browser_env.observation import extract_page_info
24+
from .lats_node import LATSNode
25+
from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
26+
from .trajectory_score import create_llm_prompt, score_trajectory_with_openai
27+
from ...webagent_utils_async.utils.utils import urls_to_images
28+
29+
logger = logging.getLogger(__name__)
30+
openai_client = OpenAI()
31+
32+
class MCTSAgent:
    """Template for an MCTS search agent.

    Construction stores the task inputs, creates the action set and a root
    node (a ``LATSNode``, the same node class the LATS agent uses), and reads
    the account reset URL from the environment. ``run`` is still a stub.
    """

    def __init__(
        self,
        starting_url: str,
        messages: list[dict[str, Any]],
        goal: str,
        images: list,
        playwright_manager: AsyncPlaywrightManager,
        config: AgentConfig,
    ):
        """Record the task inputs and build the root search node.

        Args:
            starting_url: URL stored as ``self.starting_url``.
            messages: Conversation history. NOTE: a user message stating the
                goal is appended to this very list, so the caller sees the
                change too.
            goal: Task description; also given to the root node.
            images: Image references, kept as ``self.image_urls`` and passed
                through ``urls_to_images`` (presumably converting URLs to
                images; confirm against that helper).
            playwright_manager: Browser manager stored on the instance.
            config: Agent configuration stored on the instance.

        Raises:
            KeyError: If ``ACCOUNT_RESET_URL`` is missing from the
                environment.
        """
        # What to do and where to start.
        self.starting_url = starting_url
        self.goal = goal

        # Raw references plus the result of urls_to_images.
        self.image_urls = images
        self.images = urls_to_images(images)

        # The caller's list is shared, not copied, and grows by one entry.
        self.messages = messages
        messages.append(
            {"role": "user", "content": f"The goal is: {goal}"}
        )

        # Collaborators supplied by the factory.
        self.playwright_manager = playwright_manager
        self.config = config

        # agent_type actually lists the action subsets enabled for this agent.
        enabled_subsets = ["bid", "nav", "file", "select_option"]
        self.agent_type = enabled_subsets
        action_set_options = {
            "subsets": enabled_subsets,
            "strict": False,
            "multiaction": True,
            "demo_mode": "default",
        }
        self.action_set = HighLevelActionSet(**action_set_options)

        # Empty root node carrying only the goal.
        root_fields = {
            "natural_language_description": None,
            "action": None,
            "prob": None,
            "element": None,
            "goal": goal,
            "parent": None,
        }
        self.root_node = LATSNode(**root_fields)

        # Looked up last, after the setup above, so a missing variable
        # surfaces as KeyError only once the other steps have run.
        self.reset_url = os.environ["ACCOUNT_RESET_URL"]

    async def run(self, websocket=None) -> List[Dict[str, Any]]:
        """Placeholder for the MCTS search loop.

        Args:
            websocket: Accepted but currently unused.

        Returns:
            Currently ``None``: no search is performed yet, despite the
            annotated return type.
        """
        return None

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
164164
if self.config.browser_mode == "browserbase":
165165
live_browser_url = await self.playwright_manager.get_live_browser_url()
166166
session_id = await self.playwright_manager.get_session_id()
167+
else:
168+
session_id = None
169+
live_browser_url = None
167170
await page.goto(self.starting_url, wait_until="networkidle")
168171

169172
# Send success message if websocket is provided

visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from .config import AgentConfig
88
from ..agents_async.SimpleSearchAgents.simple_search_agent import SimpleSearchAgent
9+
from ..agents_async.SimpleSearchAgents.lats_agent import LATSAgent
10+
from ..agents_async.SimpleSearchAgents.mcts_agent import MCTSAgent
911
from ..webagent_utils_async.utils.utils import setup_logger
1012
from ..webagent_utils_async.utils.playwright_manager import setup_playwright
1113

@@ -70,7 +72,8 @@ async def setup_search_agent(
7072
"content": SEARCH_AGENT_SYSTEM_PROMPT,
7173
}]
7274

73-
if agent_type == "SimpleSearchAgent":
75+
if agent_type == "SimpleSearchAgent":
76+
print("SimpleSearchAgent")
7477
agent = SimpleSearchAgent(
7578
starting_url=starting_url,
7679
messages=messages,
@@ -79,6 +82,26 @@ async def setup_search_agent(
7982
playwright_manager=playwright_manager,
8083
config=agent_config,
8184
)
85+
elif agent_type == "LATSAgent":
86+
print("LATSAgent")
87+
agent = LATSAgent(
88+
starting_url=starting_url,
89+
messages=messages,
90+
goal=goal,
91+
images = images,
92+
playwright_manager=playwright_manager,
93+
config=agent_config,
94+
)
95+
elif agent_type == "MCTSAgent":
96+
print("MCTSAgent")
97+
agent = MCTSAgent(
98+
starting_url=starting_url,
99+
messages=messages,
100+
goal=goal,
101+
images = images,
102+
playwright_manager=playwright_manager,
103+
config=agent_config,
104+
)
82105
else:
83106
error_message = f"Unsupported agent type: {agent_type}. Please use 'FunctionCallingAgent', 'HighLevelPlanningAgent', 'ContextAwarePlanningAgent', 'PromptAgent' or 'PromptSearchAgent' ."
84107
logger.error(error_message)

visual-tree-search-backend/app/api/run_demo_treesearch_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ async def main(args):
2222

2323
agent_config = AgentConfig(**filter_valid_config_args(args.__dict__))
2424
print(agent_config)
25+
2526
agent, playwright_manager = await setup_search_agent(
2627
agent_type=args.agent_type,
2728
starting_url=args.starting_url,
2829
goal=args.goal,
2930
images=args.images,
3031
agent_config=agent_config
3132
)
32-
print(agent_config)
3333

3434
# Run the search
3535
results = await agent.run()

visual-tree-search-backend/app/api/run_demo_treesearch_sync.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)