Skip to content

Commit ed61a2b

Browse files
committed
🔧 Refactor code interpreter API to use langchain_core and langchain_openai libraries
1 parent 0953050 commit ed61a2b

File tree

12 files changed

+62
-63
lines changed

12 files changed

+62
-63
lines changed

‎src/codeinterpreterapi/__init__.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1+
from . import _patch_parser # noqa
2+
13
from codeinterpreterapi.config import settings
24
from codeinterpreterapi.schema import File
35
from codeinterpreterapi.session import CodeInterpreterSession
46

5-
from ._patch_parser import patch
6-
7-
patch()
87

98
__all__ = [
109
"CodeInterpreterSession",

‎src/codeinterpreterapi/_patch_parser.py‎

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@
33
from json import JSONDecodeError
44
from typing import List, Union
55

6+
from langchain.agents.agent import AgentOutputParser
7+
from langchain.agents.openai_functions_agent import base
68
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
79
from langchain_core.exceptions import OutputParserException
8-
from langchain_core.messages import (
9-
AIMessage,
10-
BaseMessage,
11-
)
10+
from langchain_core.messages import AIMessage, BaseMessage
1211
from langchain_core.outputs import ChatGeneration, Generation
1312

14-
from langchain.agents.agent import AgentOutputParser
15-
1613

1714
class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
1815
"""Parses a message into agent action/finish.
@@ -102,8 +99,5 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
10299
raise ValueError("Can only parse messages")
103100

104101

105-
def patch() -> None:
106-
"""Patch the parser."""
107-
from langchain.agents import openai_functions_agent
108-
109-
openai_functions_agent.OpenAIFunctionsAgentOutputParser = OpenAIFunctionsAgentOutputParser # type: ignore
102+
# patch
103+
base.OpenAIFunctionsAgentOutputParser = OpenAIFunctionsAgentOutputParser # type: ignore

‎src/codeinterpreterapi/chains/extract_code.py‎

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from langchain.base_language import BaseLanguageModel
2-
from langchain.chat_models.anthropic import ChatAnthropic
1+
from langchain_core.language_models import BaseLanguageModel
32

43

54
def extract_python_code(
@@ -19,7 +18,9 @@ async def aextract_python_code(
1918

2019

2120
async def test() -> None:
22-
llm = ChatAnthropic(model="claude-1.3") # type: ignore
21+
from langchain_openai import ChatOpenAI
22+
23+
llm = ChatOpenAI()
2324

2425
code = """
2526
import matplotlib.pyplot as plt

‎src/codeinterpreterapi/chains/modifications_check.py‎

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
import json
22
from typing import List, Optional
33

4-
from langchain.base_language import BaseLanguageModel
5-
from langchain.chat_models.anthropic import ChatAnthropic
4+
from langchain_core.language_models import BaseLanguageModel
65

76
from codeinterpreterapi.prompts import determine_modifications_prompt
87

98

109
def get_file_modifications(
1110
code: str,
1211
llm: BaseLanguageModel,
13-
retry: int = 2,
12+
retry: int = 4,
1413
) -> Optional[List[str]]:
1514
if retry < 1:
1615
return None
1716

1817
prompt = determine_modifications_prompt.format(code=code)
1918

20-
result = llm.predict(prompt, stop="```")
19+
result = llm.invoke(prompt)
2120

2221
try:
23-
result = json.loads(result)
22+
result = json.loads(result.content)
2423
except json.JSONDecodeError:
2524
result = ""
2625
if not result or not isinstance(result, dict) or "modifications" not in result:
@@ -31,17 +30,17 @@ def get_file_modifications(
3130
async def aget_file_modifications(
3231
code: str,
3332
llm: BaseLanguageModel,
34-
retry: int = 2,
33+
retry: int = 4,
3534
) -> Optional[List[str]]:
3635
if retry < 1:
3736
return None
3837

3938
prompt = determine_modifications_prompt.format(code=code)
4039

41-
result = await llm.apredict(prompt, stop="```")
40+
result = await llm.ainvoke(prompt)
4241

4342
try:
44-
result = json.loads(result)
43+
result = json.loads(result.content)
4544
except json.JSONDecodeError:
4645
result = ""
4746
if not result or not isinstance(result, dict) or "modifications" not in result:
@@ -50,7 +49,9 @@ async def aget_file_modifications(
5049

5150

5251
async def test() -> None:
53-
llm = ChatAnthropic(model="claude-2") # type: ignore
52+
from langchain_openai import ChatOpenAI
53+
54+
llm = ChatOpenAI()
5455

5556
code = """
5657
import matplotlib.pyplot as plt

‎src/codeinterpreterapi/chains/rm_dl_link.py‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from langchain.base_language import BaseLanguageModel
2-
from langchain.chat_models.openai import ChatOpenAI
3-
from langchain.schema import AIMessage, OutputParserException
1+
from langchain_core.exceptions import OutputParserException
2+
from langchain_core.language_models import BaseLanguageModel
3+
from langchain_core.messages import AIMessage
4+
from langchain_openai import ChatOpenAI
45

56
from codeinterpreterapi.prompts import remove_dl_link_prompt
67

@@ -12,7 +13,7 @@ def remove_download_link(
1213
messages = remove_dl_link_prompt.format_prompt(
1314
input_response=input_response
1415
).to_messages()
15-
message = llm.predict_messages(messages)
16+
message = llm.invoke(messages)
1617

1718
if not isinstance(message, AIMessage):
1819
raise OutputParserException("Expected an AIMessage")
@@ -28,7 +29,7 @@ async def aremove_download_link(
2829
messages = remove_dl_link_prompt.format_prompt(
2930
input_response=input_response
3031
).to_messages()
31-
message = await llm.apredict_messages(messages)
32+
message = await llm.ainvoke(messages)
3233

3334
if not isinstance(message, AIMessage):
3435
raise OutputParserException("Expected an AIMessage")

‎src/codeinterpreterapi/chat_history.py‎

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import json
33
from typing import List
44

5-
from codeboxapi import CodeBox # type: ignore
6-
from langchain.schema import BaseChatMessageHistory
7-
from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
5+
from codeboxapi import CodeBox
6+
from langchain_core.chat_history import BaseChatMessageHistory
7+
from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict
88

99

10-
# TODO: This is probably not efficient, but it works for now.
1110
class CodeBoxChatMessageHistory(BaseChatMessageHistory):
1211
"""
1312
Chat message history that stores history inside the codebox.

‎src/codeinterpreterapi/config.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from typing import Optional
22

3-
from dotenv import load_dotenv
4-
from langchain.pydantic_v1 import BaseSettings, SecretStr
5-
from langchain.schema import SystemMessage
3+
from langchain_core.messages import SystemMessage
4+
from langchain_core.pydantic_v1 import BaseSettings, SecretStr
65

76
from codeinterpreterapi.prompts import code_interpreter_system_message
87

9-
# .env file
10-
load_dotenv(dotenv_path="./.env")
11-
128

139
class CodeInterpreterAPISettings(BaseSettings):
1410
"""
@@ -18,8 +14,8 @@ class CodeInterpreterAPISettings(BaseSettings):
1814
DEBUG: bool = False
1915

2016
# Models
21-
OPENAI_API_KEY: Optional[str] = None
22-
AZURE_API_KEY: Optional[str] = None
17+
OPENAI_API_KEY: Optional[SecretStr] = None
18+
AZURE_OPENAI_API_KEY: Optional[SecretStr] = None
2319
AZURE_API_BASE: Optional[str] = None
2420
AZURE_API_VERSION: Optional[str] = None
2521
AZURE_DEPLOYMENT_NAME: Optional[str] = None
@@ -46,5 +42,9 @@ class CodeInterpreterAPISettings(BaseSettings):
4642
# deprecated
4743
VERBOSE: bool = DEBUG
4844

45+
class Config:
46+
env_file = "./.env"
47+
extra = "ignore"
48+
4949

5050
settings = CodeInterpreterAPISettings()

‎src/codeinterpreterapi/prompts/modifications_check.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from langchain.prompts import PromptTemplate
1+
from langchain_core.prompts import PromptTemplate
22

33
determine_modifications_prompt = PromptTemplate(
44
input_variables=["code"],

‎src/codeinterpreterapi/prompts/remove_dl_link.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
2-
from langchain.schema import AIMessage, HumanMessage, SystemMessage
1+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2+
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
33

44
remove_dl_link_prompt = ChatPromptTemplate(
55
input_variables=["input_response"],

‎src/codeinterpreterapi/prompts/system_message.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from langchain.schema import SystemMessage
1+
from langchain_core.messages import SystemMessage
22

33
system_message = SystemMessage(
44
content="""

0 commit comments

Comments
 (0)