11import json
22from typing import List , Optional
33
4- from langchain .base_language import BaseLanguageModel
5- from langchain .chat_models .anthropic import ChatAnthropic
4+ from langchain_core .language_models import BaseLanguageModel
65
76from codeinterpreterapi .prompts import determine_modifications_prompt
87
98
109def get_file_modifications (
1110 code : str ,
1211 llm : BaseLanguageModel ,
13- retry : int = 2 ,
12+ retry : int = 4 ,
1413) -> Optional [List [str ]]:
1514 if retry < 1 :
1615 return None
1716
1817 prompt = determine_modifications_prompt .format (code = code )
1918
20- result = llm .predict (prompt , stop = "```" )
19+ result = llm .invoke (prompt )
2120
2221 try :
23- result = json .loads (result )
22+ result = json .loads (result . content )
2423 except json .JSONDecodeError :
2524 result = ""
2625 if not result or not isinstance (result , dict ) or "modifications" not in result :
@@ -31,17 +30,17 @@ def get_file_modifications(
3130async def aget_file_modifications (
3231 code : str ,
3332 llm : BaseLanguageModel ,
34- retry : int = 2 ,
33+ retry : int = 4 ,
3534) -> Optional [List [str ]]:
3635 if retry < 1 :
3736 return None
3837
3938 prompt = determine_modifications_prompt .format (code = code )
4039
41- result = await llm .apredict (prompt , stop = "```" )
40+ result = await llm .ainvoke (prompt )
4241
4342 try :
44- result = json .loads (result )
43+ result = json .loads (result . content )
4544 except json .JSONDecodeError :
4645 result = ""
4746 if not result or not isinstance (result , dict ) or "modifications" not in result :
@@ -50,7 +49,9 @@ async def aget_file_modifications(
5049
5150
5251async def test () -> None :
53- llm = ChatAnthropic (model = "claude-2" ) # type: ignore
52+ from langchain_openai import ChatOpenAI
53+
54+ llm = ChatOpenAI ()
5455
5556 code = """
5657 import matplotlib.pyplot as plt
0 commit comments