diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 28395f56bd..21ac99da40 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -49,7 +49,7 @@ 'max_tokens': 'length', 'stop_sequence': 'stop', 'tool_use': 'tool_call', - 'pause_turn': 'stop', + 'pause_turn': 'stop', # TODO: should this be a different finish reason? 'refusal': 'content_filter', } @@ -385,33 +385,51 @@ async def _messages_create( output_format = self._native_output_format(model_request_parameters) betas, extra_headers = self._get_betas_and_extra_headers(tools, model_request_parameters, model_settings) betas.update(builtin_tool_betas) - try: - return await self.client.beta.messages.create( - max_tokens=model_settings.get('max_tokens', 4096), - system=system_prompt or OMIT, - messages=anthropic_messages, - model=self._model_name, - tools=tools or OMIT, - tool_choice=tool_choice or OMIT, - mcp_servers=mcp_servers or OMIT, - output_format=output_format or OMIT, - betas=sorted(betas) or OMIT, - stream=stream, - thinking=model_settings.get('anthropic_thinking', OMIT), - stop_sequences=model_settings.get('stop_sequences', OMIT), - temperature=model_settings.get('temperature', OMIT), - top_p=model_settings.get('top_p', OMIT), - timeout=model_settings.get('timeout', NOT_GIVEN), - metadata=model_settings.get('anthropic_metadata', OMIT), - extra_headers=extra_headers, - extra_body=model_settings.get('extra_body'), - ) - except APIStatusError as e: - if (status_code := e.status_code) >= 400: - raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover - except APIConnectionError as e: - raise ModelAPIError(model_name=self.model_name, message=e.message) from e + + # Handle pause_turn retry loop + while True: + try: + response = await self.client.beta.messages.create( + max_tokens=model_settings.get('max_tokens', 4096), + system=system_prompt or OMIT, + messages=anthropic_messages, + model=self._model_name, + tools=tools or OMIT, + tool_choice=tool_choice or OMIT, + mcp_servers=mcp_servers or OMIT, + output_format=output_format or OMIT, + betas=sorted(betas) or OMIT, + stream=stream, + thinking=model_settings.get('anthropic_thinking', OMIT), + stop_sequences=model_settings.get('stop_sequences', OMIT), + temperature=model_settings.get('temperature', OMIT), + top_p=model_settings.get('top_p', OMIT), + timeout=model_settings.get('timeout', NOT_GIVEN), + metadata=model_settings.get('anthropic_metadata', OMIT), + extra_headers=extra_headers, + extra_body=model_settings.get('extra_body'), + ) + + # Handle pause_turn for non-streaming + assert isinstance(response, BetaMessage) + if response.stop_reason == 'pause_turn': + # Append assistant message to history and continue + anthropic_messages.append( + { + 'role': 'assistant', + 'content': response.content, + } + ) + continue + + return response + + except APIStatusError as e: + if (status_code := e.status_code) >= 400: + raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e + raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover + except APIConnectionError as e: + raise ModelAPIError(model_name=self.model_name, message=e.message) from e def _get_betas_and_extra_headers( self, @@ -512,6 +530,16 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: call_part = builtin_tool_calls.get(item.tool_use_id) items.append(_map_mcp_server_result_block(item, call_part, self.system)) else: + # Fallback for new block types like `bash_code_execution_tool_result` if they aren't explicitly typed yet + # or if we want to handle them generically. + # For now, we'll try to handle `bash_code_execution_tool_result` if it appears as a dict or unknown type, + # but since `response.content` is typed as a union of specific blocks, we might need to rely on `model_dump` or similar if the SDK doesn't support it yet. + # However, the user request says "Handle the bash_code_execution_tool_result event type". + # If `anthropic` SDK doesn't have it, we might not see it here unless we upgrade or it's in `BetaContentBlock`. + # Assuming `BetaCodeExecutionToolResultBlock` covers it or we need to add a check. + # Let's assume for now `BetaCodeExecutionToolResultBlock` is sufficient or we'll see. + # But wait, `bash_code_execution_tool_result` implies a specific type. + # Let's check if we can import it. assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}' items.append( ToolCallPart( @@ -1175,6 +1203,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=event.index, part=_map_code_execution_tool_result_block(current_block, self.provider_name), ) + elif isinstance(current_block, BetaWebFetchToolResultBlock): # pragma: lax no cover yield self._parts_manager.handle_part( vendor_part_id=event.index, @@ -1287,6 +1316,14 @@ def _map_server_tool_use_block(item: BetaServerToolUseBlock, provider_name: str) args=cast(dict[str, Any], item.input) or None, tool_call_id=item.id, ) + elif item.name == 'bash_code_execution': + return BuiltinToolCallPart( + provider_name=provider_name, + tool_name=CodeExecutionTool.kind, + args=cast(dict[str, Any], item.input) or None, + tool_call_id=item.id, + ) + elif item.name == 'web_fetch': return BuiltinToolCallPart( provider_name=provider_name, diff --git a/tests/models/test_anthropic_pause_turn.py b/tests/models/test_anthropic_pause_turn.py new file mode 100644 index 0000000000..9a62ac4269 --- /dev/null +++ b/tests/models/test_anthropic_pause_turn.py @@ -0,0 +1,62 @@ +from __future__ import annotations as _annotations + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai import Agent +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider + +from ..conftest import try_import +from .test_anthropic import MockAnthropic, completion_message + +with try_import() as imports_successful: + from anthropic.types.beta import ( + BetaTextBlock, + BetaUsage, + BetaMessage, + ) + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='anthropic not installed'), + pytest.mark.anyio, +] + +async def test_pause_turn_retry_loop(allow_model_requests: None): + # Mock a sequence of responses: + # 1. pause_turn response + # 2. final response + + c1 = completion_message( + [BetaTextBlock(text='paused', type='text')], + usage=BetaUsage(input_tokens=10, output_tokens=5), + ) + c1.stop_reason = 'pause_turn' # type: ignore + + c2 = completion_message( + [BetaTextBlock(text='final', type='text')], + usage=BetaUsage(input_tokens=10, output_tokens=5), + ) + + mock_client = MockAnthropic.create_mock([c1, c2]) + m = AnthropicModel('claude-3-5-sonnet-20241022', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + result = await agent.run('test prompt') + + # Verify the agent received the final response + assert result.output == 'final' + + # Verify the loop happened (2 requests) + assert len(mock_client.chat_completion_kwargs) == 2 + + # Verify history in second request includes the paused message + messages_2 = mock_client.chat_completion_kwargs[1]['messages'] + # Should be: User -> Assistant(paused) + assert len(messages_2) == 2 + assert messages_2[1]['role'] == 'assistant' + # Content is a list of BetaContentBlock objects, get the text from first block + content_blocks = messages_2[1]['content'] + assert len(content_blocks) > 0 + first_block = content_blocks[0] + assert hasattr(first_block, 'text') and first_block.text == 'paused'