Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 65 additions & 28 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
'max_tokens': 'length',
'stop_sequence': 'stop',
'tool_use': 'tool_call',
'pause_turn': 'stop',
'pause_turn': 'stop', # TODO: should this be a different finish reason?
'refusal': 'content_filter',
}

Expand Down Expand Up @@ -385,33 +385,51 @@ async def _messages_create(
output_format = self._native_output_format(model_request_parameters)
betas, extra_headers = self._get_betas_and_extra_headers(tools, model_request_parameters, model_settings)
betas.update(builtin_tool_betas)
try:
return await self.client.beta.messages.create(
max_tokens=model_settings.get('max_tokens', 4096),
system=system_prompt or OMIT,
messages=anthropic_messages,
model=self._model_name,
tools=tools or OMIT,
tool_choice=tool_choice or OMIT,
mcp_servers=mcp_servers or OMIT,
output_format=output_format or OMIT,
betas=sorted(betas) or OMIT,
stream=stream,
thinking=model_settings.get('anthropic_thinking', OMIT),
stop_sequences=model_settings.get('stop_sequences', OMIT),
temperature=model_settings.get('temperature', OMIT),
top_p=model_settings.get('top_p', OMIT),
timeout=model_settings.get('timeout', NOT_GIVEN),
metadata=model_settings.get('anthropic_metadata', OMIT),
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
)
except APIStatusError as e:
if (status_code := e.status_code) >= 400:
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover
except APIConnectionError as e:
raise ModelAPIError(model_name=self.model_name, message=e.message) from e

# Handle pause_turn retry loop
while True:
try:
response = await self.client.beta.messages.create(
max_tokens=model_settings.get('max_tokens', 4096),
system=system_prompt or OMIT,
messages=anthropic_messages,
model=self._model_name,
tools=tools or OMIT,
tool_choice=tool_choice or OMIT,
mcp_servers=mcp_servers or OMIT,
output_format=output_format or OMIT,
betas=sorted(betas) or OMIT,
stream=stream,
thinking=model_settings.get('anthropic_thinking', OMIT),
stop_sequences=model_settings.get('stop_sequences', OMIT),
temperature=model_settings.get('temperature', OMIT),
top_p=model_settings.get('top_p', OMIT),
timeout=model_settings.get('timeout', NOT_GIVEN),
metadata=model_settings.get('anthropic_metadata', OMIT),
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
)

# Handle pause_turn for non-streaming
assert isinstance(response, BetaMessage)
if response.stop_reason == 'pause_turn':
# Append assistant message to history and continue
anthropic_messages.append(
{
'role': 'assistant',
'content': response.content,
}
)
continue

return response

except APIStatusError as e:
if (status_code := e.status_code) >= 400:
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover
except APIConnectionError as e:
raise ModelAPIError(model_name=self.model_name, message=e.message) from e

def _get_betas_and_extra_headers(
self,
Expand Down Expand Up @@ -512,6 +530,16 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
call_part = builtin_tool_calls.get(item.tool_use_id)
items.append(_map_mcp_server_result_block(item, call_part, self.system))
else:
# Fallback for content block types not explicitly handled above.
# Only plain tool-use blocks are expected to reach this branch; any other
# type (e.g. a future `bash_code_execution_tool_result` block, if/when the
# anthropic SDK exposes it) will fail the assertion below until explicit
# handling is added.
assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
items.append(
ToolCallPart(
Expand Down Expand Up @@ -1175,6 +1203,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
vendor_part_id=event.index,
part=_map_code_execution_tool_result_block(current_block, self.provider_name),
)

elif isinstance(current_block, BetaWebFetchToolResultBlock): # pragma: lax no cover
yield self._parts_manager.handle_part(
vendor_part_id=event.index,
Expand Down Expand Up @@ -1287,6 +1316,14 @@ def _map_server_tool_use_block(item: BetaServerToolUseBlock, provider_name: str)
args=cast(dict[str, Any], item.input) or None,
tool_call_id=item.id,
)
elif item.name == 'bash_code_execution':
return BuiltinToolCallPart(
provider_name=provider_name,
tool_name=CodeExecutionTool.kind,
args=cast(dict[str, Any], item.input) or None,
tool_call_id=item.id,
)

elif item.name == 'web_fetch':
return BuiltinToolCallPart(
provider_name=provider_name,
Expand Down
62 changes: 62 additions & 0 deletions tests/models/test_anthropic_pause_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations as _annotations

import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

from ..conftest import try_import
from .test_anthropic import MockAnthropic, completion_message

with try_import() as imports_successful:
from anthropic.types.beta import (
BetaTextBlock,
BetaUsage,
BetaMessage,
)

# Module-level pytest configuration: skip every test here when the
# `anthropic` SDK isn't installed, and run them under anyio so plain
# `async def` tests work without per-test decorators.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='anthropic not installed'),
    pytest.mark.anyio,
]

async def test_pause_turn_retry_loop(allow_model_requests: None):
    """Verify the `pause_turn` retry loop in the Anthropic model.

    The mocked client first returns a response whose ``stop_reason`` is
    ``pause_turn`` and then a normal final response; the model should
    transparently re-request and surface only the final output, carrying
    the paused assistant turn forward in the second request's history.
    """
    paused_response = completion_message(
        [BetaTextBlock(text='paused', type='text')],
        usage=BetaUsage(input_tokens=10, output_tokens=5),
    )
    # `completion_message` doesn't expose stop_reason, so patch it directly.
    paused_response.stop_reason = 'pause_turn'  # type: ignore

    final_response = completion_message(
        [BetaTextBlock(text='final', type='text')],
        usage=BetaUsage(input_tokens=10, output_tokens=5),
    )

    client = MockAnthropic.create_mock([paused_response, final_response])
    model = AnthropicModel('claude-3-5-sonnet-20241022', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(model)

    run_result = await agent.run('test prompt')

    # Only the final (non-paused) text reaches the caller.
    assert run_result.output == 'final'

    # Exactly two API calls were made: the paused one plus the retry.
    assert len(client.chat_completion_kwargs) == 2

    # The retry request's history is: user prompt -> paused assistant turn.
    retry_messages = client.chat_completion_kwargs[1]['messages']
    assert len(retry_messages) == 2
    assert retry_messages[1]['role'] == 'assistant'
    # Content is a list of content-block objects; the first carries the text.
    retry_content = retry_messages[1]['content']
    assert len(retry_content) > 0
    leading_block = retry_content[0]
    assert hasattr(leading_block, 'text') and leading_block.text == 'paused'
Loading