Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def model_name(self) -> str:
def provider_name(self) -> str:
return self.response.provider_name or '' # pragma: no cover

@property
def provider_url(self) -> str | None:
return self.response.provider_url # pragma: no cover

@property
def timestamp(self) -> datetime:
return self.response.timestamp # pragma: no cover
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def model_name(self) -> str:
def provider_name(self) -> str:
return self.response.provider_name or '' # pragma: no cover

@property
def provider_url(self) -> str | None:
return self.response.provider_url # pragma: no cover

@property
def timestamp(self) -> datetime:
return self.response.timestamp # pragma: no cover
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def model_name(self) -> str:
def provider_name(self) -> str:
return self.response.provider_name or '' # pragma: no cover

@property
def provider_url(self) -> str | None:
return self.response.provider_url # pragma: no cover

@property
def timestamp(self) -> datetime:
return self.response.timestamp # pragma: no cover
Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,9 @@ class ModelResponse:
provider_name: str | None = None
"""The name of the LLM provider that generated the response."""

provider_url: str | None = None
"""The base URL of the LLM provider that generated the response."""

provider_details: Annotated[
dict[str, Any] | None,
# `vendor_details` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
Expand Down Expand Up @@ -1340,7 +1343,7 @@ def cost(self) -> genai_types.PriceCalculation:
return calc_price(
self.usage,
self.model_name,
provider_id=self.provider_name,
provider_api_url=self.provider_url,
genai_request_timestamp=self.timestamp,
)

Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ def get(self) -> ModelResponse:
timestamp=self.timestamp,
usage=self.usage(),
provider_name=self.provider_name,
provider_url=self.provider_url,
provider_response_id=self.provider_response_id,
provider_details=self.provider_details,
finish_reason=self.finish_reason,
Expand All @@ -898,6 +899,12 @@ def provider_name(self) -> str | None:
"""Get the provider name."""
raise NotImplementedError()

@property
@abstractmethod
def provider_url(self) -> str | None:
"""Get the provider base URL."""
raise NotImplementedError()

@property
@abstractmethod
def timestamp(self) -> datetime:
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
model_name=response.model,
provider_response_id=response.id,
provider_name=self._provider.name,
provider_url=self._provider.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand Down Expand Up @@ -1266,6 +1267,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ async def request_stream(
_model_name=self.model_name,
_event_stream=response['stream'],
_provider_name=self._provider.name,
_provider_url=self.base_url,
_provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
)

Expand Down Expand Up @@ -388,6 +389,7 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
model_name=self.model_name,
provider_response_id=response_id,
provider_name=self._provider.name,
provider_url=self.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand Down Expand Up @@ -706,6 +708,7 @@ class BedrockStreamedResponse(StreamedResponse):
_model_name: BedrockModelName
_event_stream: EventStream[ConverseStreamOutputTypeDef]
_provider_name: str
_provider_url: str
_timestamp: datetime = field(default_factory=_utils.now_utc)
_provider_response_id: str | None = None

Expand Down Expand Up @@ -793,6 +796,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
return self._timestamp
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
usage=_map_usage(response),
model_name=self._model_name,
provider_name=self._provider.name,
provider_url=self.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,11 @@ def provider_name(self) -> None:
"""Get the provider name."""
return None

@property
def provider_url(self) -> None:
"""Get the provider base URL."""
return None

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
20 changes: 17 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def __init__(

@property
def base_url(self) -> str:
assert self._url is not None, 'URL not initialized' # pragma: no cover
return self._url # pragma: no cover
assert self._url is not None, 'URL not initialized'
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[note]
This change now calls this property to retrieve provider_url, so it's now covered by tests.
Removing # pragma: no cover.

return self._url

@property
def model_name(self) -> GeminiModelName:
Expand Down Expand Up @@ -298,6 +298,7 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
usage,
vendor_id=vendor_id,
vendor_details=vendor_details,
provider_url=self.base_url,
)

async def _process_streamed_response(
Expand Down Expand Up @@ -329,6 +330,7 @@ async def _process_streamed_response(
_content=content,
_stream=aiter_bytes,
_provider_name=self._provider.name,
_provider_url=self.base_url,
)

async def _message_to_gemini_content(
Expand Down Expand Up @@ -453,6 +455,7 @@ class GeminiStreamedResponse(StreamedResponse):
_content: bytearray
_stream: AsyncIterator[bytes]
_provider_name: str
_provider_url: str
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
Expand Down Expand Up @@ -527,6 +530,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down Expand Up @@ -713,6 +721,7 @@ def _process_response_from_parts(
model_name: GeminiModelName,
usage: usage.RequestUsage,
vendor_id: str | None,
provider_url: str,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
items: list[ModelResponsePart] = []
Expand All @@ -731,7 +740,12 @@ def _process_response_from_parts(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(
parts=items, usage=usage, model_name=model_name, provider_response_id=vendor_id, provider_details=vendor_details
parts=items,
usage=usage,
model_name=model_name,
provider_response_id=vendor_id,
provider_details=vendor_details,
provider_url=provider_url,
)


Expand Down
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
candidate.grounding_metadata,
response.model_version or self._model_name,
self._provider.name,
self._provider.base_url,
usage,
vendor_id=vendor_id,
vendor_details=vendor_details,
Expand Down Expand Up @@ -780,6 +781,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down Expand Up @@ -858,6 +864,7 @@ def _process_response_from_parts(
grounding_metadata: GroundingMetadata | None,
model_name: GoogleModelName,
provider_name: str,
provider_url: str,
usage: usage.RequestUsage,
vendor_id: str | None,
vendor_details: dict[str, Any] | None = None,
Expand Down Expand Up @@ -927,6 +934,7 @@ def _process_response_from_parts(
provider_response_id=vendor_id,
provider_details=vendor_details,
provider_name=provider_name,
provider_url=provider_url,
finish_reason=finish_reason,
)

Expand Down
9 changes: 9 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def request(
model_name=e.model_name,
timestamp=_utils.now_utc(),
provider_name=self._provider.name,
provider_url=self.base_url,
finish_reason='error',
)
except ValidationError:
Expand Down Expand Up @@ -349,6 +350,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
timestamp=timestamp,
provider_response_id=response.id,
provider_name=self._provider.name,
provider_url=self.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand All @@ -371,6 +373,7 @@ async def _process_streamed_response(
_model_profile=self.profile,
_timestamp=number_to_datetime(first_chunk.created),
_provider_name=self._provider.name,
_provider_url=self.base_url,
)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
Expand Down Expand Up @@ -524,6 +527,7 @@ class GroqStreamedResponse(StreamedResponse):
_response: AsyncIterable[chat.ChatCompletionChunk]
_timestamp: datetime
_provider_name: str
_provider_url: str

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
try:
Expand Down Expand Up @@ -621,6 +625,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
13 changes: 13 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def __init__(

super().__init__(settings=settings, profile=profile or provider.model_profile)

@property
def base_url(self) -> str:
"""The base URL of the provider."""
return self._provider.base_url

@property
def model_name(self) -> HuggingFaceModelName:
"""The model name."""
Expand Down Expand Up @@ -295,6 +300,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
timestamp=timestamp,
provider_response_id=response.id,
provider_name=self._provider.name,
provider_url=self.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand All @@ -317,6 +323,7 @@ async def _process_streamed_response(
_response=peekable_response,
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
_provider_name=self._provider.name,
_provider_url=self.base_url,
)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
Expand Down Expand Up @@ -465,6 +472,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
_response: AsyncIterable[ChatCompletionStreamOutput]
_timestamp: datetime
_provider_name: str
_provider_url: str

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async for chunk in self._response:
Expand Down Expand Up @@ -515,6 +523,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes
timestamp=timestamp,
provider_response_id=response.id,
provider_name=self._provider.name,
provider_url=self._provider.base_url,
finish_reason=finish_reason,
provider_details=provider_details,
)
Expand Down Expand Up @@ -408,6 +409,7 @@ async def _process_streamed_response(
_model_name=first_chunk.data.model,
_timestamp=timestamp,
_provider_name=self._provider.name,
_provider_url=self._provider.base_url,
)

@staticmethod
Expand Down Expand Up @@ -615,6 +617,7 @@ class MistralStreamedResponse(StreamedResponse):
_response: AsyncIterable[MistralCompletionEvent]
_timestamp: datetime
_provider_name: str
_provider_url: str

_delta_content: str = field(default='', init=False)

Expand Down Expand Up @@ -676,6 +679,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_url(self) -> str:
"""Get the provider base URL."""
return self._provider_url

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
Loading