pydantic-ai-slim 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. See the package's registry page for more details.

@@ -77,6 +77,7 @@ class InstrumentationSettings:
77
77
  tracer: Tracer = field(repr=False)
78
78
  event_logger: EventLogger = field(repr=False)
79
79
  event_mode: Literal['attributes', 'logs'] = 'attributes'
80
+ include_binary_content: bool = True
80
81
 
81
82
  def __init__(
82
83
  self,
@@ -84,6 +85,7 @@ class InstrumentationSettings:
84
85
  event_mode: Literal['attributes', 'logs'] = 'attributes',
85
86
  tracer_provider: TracerProvider | None = None,
86
87
  event_logger_provider: EventLoggerProvider | None = None,
88
+ include_binary_content: bool = True,
87
89
  ):
88
90
  """Create instrumentation options.
89
91
 
@@ -97,6 +99,7 @@ class InstrumentationSettings:
97
99
  If not provided, the global event logger provider is used.
98
100
  Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
99
101
  This is only used if `event_mode='logs'`.
102
+ include_binary_content: Whether to include binary content in the instrumentation events.
100
103
  """
101
104
  from pydantic_ai import __version__
102
105
 
@@ -105,6 +108,40 @@ class InstrumentationSettings:
105
108
  self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
106
109
  self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
107
110
  self.event_mode = event_mode
111
+ self.include_binary_content = include_binary_content
112
+
113
def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
    """Translate model messages into OpenTelemetry events.

    When the messages carry instructions, a `gen_ai.system.message` event is emitted first.
    Request parts that expose `otel_event` and responses (via `otel_events()`) then contribute
    events, each tagged with the position of its originating message under
    `gen_ai.message.index`. Finally every event body is serialized with
    `InstrumentedModel.serialize_any`.

    Args:
        messages: The messages to convert.

    Returns:
        A list of OpenTelemetry events.
    """
    otel_events: list[Event] = []

    system_instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
    if system_instructions is not None:
        otel_events.append(
            Event('gen_ai.system.message', body={'content': system_instructions, 'role': 'system'})
        )

    for index, msg in enumerate(messages):
        if isinstance(msg, ModelRequest):
            # The settings object (self) is handed to each part; presumably parts consult options
            # such as `include_binary_content` when rendering — TODO confirm in the part classes.
            per_message = [p.otel_event(self) for p in msg.parts if hasattr(p, 'otel_event')]
        elif isinstance(msg, ModelResponse):  # pragma: no branch
            per_message = msg.otel_events()
        else:
            per_message = []

        for evt in per_message:
            # Event-specific attributes win over the injected message index.
            evt.attributes = {'gen_ai.message.index': index, **(evt.attributes or {})}
        otel_events.extend(per_message)

    for evt in otel_events:
        evt.body = InstrumentedModel.serialize_any(evt.body)
    return otel_events
108
145
 
109
146
 
110
147
  GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
@@ -155,7 +192,7 @@ class InstrumentedModel(WrapperModel):
155
192
  ) as response_stream:
156
193
  yield response_stream
157
194
  finally:
158
- if response_stream:
195
+ if response_stream: # pragma: no branch
159
196
  finish(response_stream.get())
160
197
 
161
198
  @contextmanager
@@ -193,8 +230,8 @@ class InstrumentedModel(WrapperModel):
193
230
  if not span.is_recording():
194
231
  return
195
232
 
196
- events = self.messages_to_otel_events(messages)
197
- for event in self.messages_to_otel_events([response]):
233
+ events = self.settings.messages_to_otel_events(messages)
234
+ for event in self.settings.messages_to_otel_events([response]):
198
235
  events.append(
199
236
  Event(
200
237
  'gen_ai.choice',
@@ -253,9 +290,9 @@ class InstrumentedModel(WrapperModel):
253
290
  except Exception: # pragma: no cover
254
291
  pass
255
292
  else:
256
- if parsed.hostname:
293
+ if parsed.hostname: # pragma: no branch
257
294
  attributes['server.address'] = parsed.hostname
258
- if parsed.port:
295
+ if parsed.port: # pragma: no branch
259
296
  attributes['server.port'] = parsed.port
260
297
 
261
298
  return attributes
@@ -263,40 +300,13 @@ class InstrumentedModel(WrapperModel):
263
300
  @staticmethod
264
301
  def event_to_dict(event: Event) -> dict[str, Any]:
265
302
  if not event.body:
266
- body = {}
303
+ body = {} # pragma: no cover
267
304
  elif isinstance(event.body, Mapping):
268
305
  body = event.body # type: ignore
269
306
  else:
270
307
  body = {'body': event.body}
271
308
  return {**body, **(event.attributes or {})}
272
309
 
273
- @staticmethod
274
- def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
275
- events: list[Event] = []
276
- last_model_request: ModelRequest | None = None
277
- for message_index, message in enumerate(messages):
278
- message_events: list[Event] = []
279
- if isinstance(message, ModelRequest):
280
- last_model_request = message
281
- for part in message.parts:
282
- if hasattr(part, 'otel_event'):
283
- message_events.append(part.otel_event())
284
- elif isinstance(message, ModelResponse):
285
- message_events = message.otel_events()
286
- for event in message_events:
287
- event.attributes = {
288
- 'gen_ai.message.index': message_index,
289
- **(event.attributes or {}),
290
- }
291
- events.extend(message_events)
292
- if last_model_request and last_model_request.instructions:
293
- events.insert(
294
- 0, Event('gen_ai.system.message', body={'content': last_model_request.instructions, 'role': 'system'})
295
- )
296
- for event in events:
297
- event.body = InstrumentedModel.serialize_any(event.body)
298
- return events
299
-
300
310
  @staticmethod
301
311
  def serialize_any(value: Any) -> str:
302
312
  try:
@@ -71,7 +71,7 @@ try:
71
71
  from mistralai.models.usermessage import UserMessage as MistralUserMessage
72
72
  from mistralai.types.basemodel import Unset as MistralUnset
73
73
  from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
74
- except ImportError as e:
74
+ except ImportError as e: # pragma: lax no cover
75
75
  raise ImportError(
76
76
  'Please install `mistral` to use the Mistral model, '
77
77
  'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -208,7 +208,7 @@ class MistralModel(Model):
208
208
  except SDKError as e:
209
209
  if (status_code := e.status_code) >= 400:
210
210
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
211
- raise
211
+ raise # pragma: lax no cover
212
212
 
213
213
  assert response, 'A unexpected empty response from Mistral.'
214
214
  return response
@@ -325,7 +325,9 @@ class MistralModel(Model):
325
325
  tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
326
326
  parts.append(tool)
327
327
 
328
- return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
328
+ return ModelResponse(
329
+ parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
330
+ )
329
331
 
330
332
  async def _process_streamed_response(
331
333
  self,
@@ -336,7 +338,9 @@ class MistralModel(Model):
336
338
  peekable_response = _utils.PeekableAsyncStream(response)
337
339
  first_chunk = await peekable_response.peek()
338
340
  if isinstance(first_chunk, _utils.Unset):
339
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
341
+ raise UnexpectedModelBehavior( # pragma: no cover
342
+ 'Streamed response ended without content or tool calls'
343
+ )
340
344
 
341
345
  if first_chunk.data.created:
342
346
  timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
@@ -437,7 +441,7 @@ class MistralModel(Model):
437
441
  """Convert a timeout to milliseconds."""
438
442
  if timeout is None:
439
443
  return None
440
- if isinstance(timeout, float):
444
+ if isinstance(timeout, float): # pragma: no cover
441
445
  return int(1000 * timeout)
442
446
  raise NotImplementedError('Timeout object is not yet supported for MistralModel.')
443
447
 
@@ -454,7 +458,7 @@ class MistralModel(Model):
454
458
  )
455
459
  elif isinstance(part, RetryPromptPart):
456
460
  if part.tool_name is None:
457
- yield MistralUserMessage(content=part.model_response())
461
+ yield MistralUserMessage(content=part.model_response()) # pragma: no cover
458
462
  else:
459
463
  yield MistralToolMessage(
460
464
  tool_call_id=part.tool_call_id,
@@ -519,7 +523,7 @@ class MistralModel(Model):
519
523
  else:
520
524
  raise RuntimeError('Only image binary content is supported for Mistral.')
521
525
  elif isinstance(item, DocumentUrl):
522
- raise RuntimeError('DocumentUrl is not supported in Mistral.')
526
+ raise RuntimeError('DocumentUrl is not supported in Mistral.') # pragma: no cover
523
527
  elif isinstance(item, VideoUrl):
524
528
  raise RuntimeError('VideoUrl is not supported in Mistral.')
525
529
  else: # pragma: no cover
@@ -663,7 +667,7 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
663
667
  details=None,
664
668
  )
665
669
  else:
666
- return Usage()
670
+ return Usage() # pragma: no cover
667
671
 
668
672
 
669
673
  def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
@@ -677,7 +681,9 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
677
681
  if isinstance(chunk, MistralTextChunk):
678
682
  output = output or '' + chunk.text
679
683
  else:
680
- assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
684
+ assert False, ( # pragma: no cover
685
+ f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
686
+ )
681
687
  elif isinstance(content, str):
682
688
  output = content
683
689
 
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
104
104
  result in faster responses and fewer tokens used on reasoning in a response.
105
105
  """
106
106
 
107
+ openai_logprobs: bool
108
+ """Include log probabilities in the response."""
109
+
110
+ openai_top_logprobs: int
111
+ """Include log probabilities of the top n tokens in the response."""
112
+
107
113
  openai_user: str
108
114
  """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
109
115
 
@@ -164,7 +170,7 @@ class OpenAIModel(Model):
164
170
  self,
165
171
  model_name: OpenAIModelName,
166
172
  *,
167
- provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
173
+ provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
168
174
  system_prompt_role: OpenAISystemPromptRole | None = None,
169
175
  ):
170
176
  """Initialize an OpenAI model.
@@ -287,6 +293,8 @@ class OpenAIModel(Model):
287
293
  frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
288
294
  logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
289
295
  reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
296
+ logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
297
+ top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
290
298
  user=model_settings.get('openai_user', NOT_GIVEN),
291
299
  extra_headers=extra_headers,
292
300
  extra_body=model_settings.get('extra_body'),
@@ -294,26 +302,54 @@ class OpenAIModel(Model):
294
302
  except APIStatusError as e:
295
303
  if (status_code := e.status_code) >= 400:
296
304
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
297
- raise
305
+ raise # pragma: lax no cover
298
306
 
299
307
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
300
308
  """Process a non-streamed response, and prepare a message to return."""
301
309
  timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
302
310
  choice = response.choices[0]
303
311
  items: list[ModelResponsePart] = []
312
+ vendor_details: dict[str, Any] | None = None
313
+
314
+ # Add logprobs to vendor_details if available
315
+ if choice.logprobs is not None and choice.logprobs.content:
316
+ # Convert logprobs to a serializable format
317
+ vendor_details = {
318
+ 'logprobs': [
319
+ {
320
+ 'token': lp.token,
321
+ 'bytes': lp.bytes,
322
+ 'logprob': lp.logprob,
323
+ 'top_logprobs': [
324
+ {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
325
+ ],
326
+ }
327
+ for lp in choice.logprobs.content
328
+ ],
329
+ }
330
+
304
331
  if choice.message.content is not None:
305
332
  items.append(TextPart(choice.message.content))
306
333
  if choice.message.tool_calls is not None:
307
334
  for c in choice.message.tool_calls:
308
335
  items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
309
- return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
336
+ return ModelResponse(
337
+ items,
338
+ usage=_map_usage(response),
339
+ model_name=response.model,
340
+ timestamp=timestamp,
341
+ vendor_details=vendor_details,
342
+ vendor_id=response.id,
343
+ )
310
344
 
311
345
  async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
312
346
  """Process a streamed response, and prepare a streaming response to return."""
313
347
  peekable_response = _utils.PeekableAsyncStream(response)
314
348
  first_chunk = await peekable_response.peek()
315
349
  if isinstance(first_chunk, _utils.Unset):
316
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
350
+ raise UnexpectedModelBehavior( # pragma: no cover
351
+ 'Streamed response ended without content or tool calls'
352
+ )
317
353
 
318
354
  return OpenAIStreamedResponse(
319
355
  _model_name=self._model_name,
@@ -399,7 +435,9 @@ class OpenAIModel(Model):
399
435
  )
400
436
  elif isinstance(part, RetryPromptPart):
401
437
  if part.tool_name is None:
402
- yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
438
+ yield chat.ChatCompletionUserMessageParam( # pragma: no cover
439
+ role='user', content=part.model_response()
440
+ )
403
441
  else:
404
442
  yield chat.ChatCompletionToolMessageParam(
405
443
  role='tool',
@@ -637,7 +675,7 @@ class OpenAIResponsesModel(Model):
637
675
  except APIStatusError as e:
638
676
  if (status_code := e.status_code) >= 400:
639
677
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
640
- raise
678
+ raise # pragma: lax no cover
641
679
 
642
680
  def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
643
681
  reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -867,7 +905,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
867
905
  args=chunk.delta,
868
906
  tool_call_id=chunk.item_id,
869
907
  )
870
- if maybe_event is not None:
908
+ if maybe_event is not None: # pragma: no branch
871
909
  yield maybe_event
872
910
 
873
911
  elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDoneEvent):
@@ -1031,7 +1069,7 @@ class _OpenAIJsonSchema(WalkJsonSchema):
1031
1069
  notes.append(f'{key}={value}')
1032
1070
  notes_string = ', '.join(notes)
1033
1071
  schema['description'] = notes_string if not description else f'{description} ({notes_string})'
1034
- elif self.strict is None:
1072
+ elif self.strict is None: # pragma: no branch
1035
1073
  self.is_strict_compatible = False
1036
1074
 
1037
1075
  schema_type = schema.get('type')
@@ -169,7 +169,7 @@ class TestModel(Model):
169
169
  model_name=self._model_name,
170
170
  )
171
171
 
172
- if messages:
172
+ if messages: # pragma: no branch
173
173
  last_message = messages[-1]
174
174
  assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
175
175
 
@@ -48,4 +48,4 @@ class WrapperModel(Model):
48
48
  return self.wrapped.system
49
49
 
50
50
  def __getattr__(self, item: str):
51
- return getattr(self.wrapped, item)
51
+ return getattr(self.wrapped, item) # pragma: no cover
@@ -52,6 +52,10 @@ def infer_provider(provider: str) -> Provider[Any]:
52
52
  from .deepseek import DeepSeekProvider
53
53
 
54
54
  return DeepSeekProvider()
55
+ elif provider == 'openrouter':
56
+ from .openrouter import OpenRouterProvider
57
+
58
+ return OpenRouterProvider()
55
59
  elif provider == 'azure':
56
60
  from .azure import AzureProvider
57
61
 
@@ -0,0 +1,143 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import Literal, overload
5
+
6
+ from pydantic_ai.exceptions import UserError
7
+ from pydantic_ai.models import get_user_agent
8
+ from pydantic_ai.providers import Provider
9
+
10
+ try:
11
+ from google import genai
12
+ from google.auth.credentials import Credentials
13
+ except ImportError as _import_error:
14
+ raise ImportError(
15
+ 'Please install the `google-genai` package to use the Google provider, '
16
+ 'you can use the `google` optional group — `pip install "pydantic-ai-slim[google]"`'
17
+ ) from _import_error
18
+
19
+
20
class GoogleProvider(Provider[genai.Client]):
    """Provider for Google.

    Builds (or wraps) a `genai.Client` that targets either the Gemini Developer API
    (reported as `google-gla`) or Vertex AI (reported as `google-vertex`).
    """

    @property
    def name(self) -> str:
        return 'google-vertex' if self._client._api_client.vertexai else 'google-gla'  # type: ignore[reportPrivateUsage]

    @property
    def base_url(self) -> str:
        return str(self._client._api_client._http_options.base_url)  # type: ignore[reportPrivateUsage]

    @property
    def client(self) -> genai.Client:
        return self._client

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(
        self,
        *,
        credentials: Credentials | None = None,
        project: str | None = None,
        location: VertexAILocation | Literal['global'] | None = None,
    ) -> None: ...

    @overload
    def __init__(self, *, client: genai.Client) -> None: ...

    @overload
    def __init__(self, *, vertexai: bool = False) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        credentials: Credentials | None = None,
        project: str | None = None,
        location: VertexAILocation | Literal['global'] | None = None,
        client: genai.Client | None = None,
        vertexai: bool | None = None,
    ) -> None:
        """Create a new Google provider.

        Args:
            api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
                use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable;
                the legacy `GEMINI_API_KEY` variable is used as a fallback.
                Applies to the Gemini Developer API only.
            credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be
                obtained from environment variables and default credentials. For more information, see Set up
                Application Default Credentials. Applies to the Vertex AI API only.
            project: The Google Cloud project ID to use for quota. Falls back to the `GOOGLE_CLOUD_PROJECT`
                environment variable. Applies to the Vertex AI API only.
            location: The location to send API requests to (for example, us-central1). Falls back to the
                `GOOGLE_CLOUD_LOCATION` environment variable. Applies to the Vertex AI API only.
            client: A pre-initialized client to use. When given, all other arguments are ignored.
            vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used.
                If not provided, Vertex AI is used when any of `credentials`, `project` or `location` is given,
                and the Google Generative Language API otherwise.

        Raises:
            UserError: If the Google Generative Language API is selected and no API key is available.
        """
        if client is not None:
            self._client = client  # pragma: lax no cover
            return

        # NOTE: `GEMINI_API_KEY` is still honoured as a fallback for backwards compatibility.
        api_key = api_key or os.environ.get('GOOGLE_API_KEY') or os.environ.get('GEMINI_API_KEY')

        if vertexai is None:  # pragma: lax no cover
            # Passing any Vertex-only option implies the caller wants Vertex AI.
            vertexai = bool(location or project or credentials)

        if not vertexai:
            # `not api_key` also rejects an empty-string key taken from the environment.
            if not api_key:
                raise UserError(  # pragma: no cover
                    'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)` '
                    'to use the Google Generative Language API.'
                )
            self._client = genai.Client(
                vertexai=vertexai,
                api_key=api_key,
                http_options={'headers': {'User-Agent': get_user_agent()}},
            )
        else:
            self._client = genai.Client(
                vertexai=vertexai,
                project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'),
                location=location or os.environ.get('GOOGLE_CLOUD_LOCATION'),
                credentials=credentials,
                http_options={'headers': {'User-Agent': get_user_agent()}},
            )
108
+
109
+
110
VertexAILocation = Literal[
    # asia-* / australia-*
    'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast3', 'asia-south1', 'asia-southeast1',
    'australia-southeast1',
    # europe-*
    'europe-central2', 'europe-north1', 'europe-southwest1',
    'europe-west1', 'europe-west2', 'europe-west3', 'europe-west4', 'europe-west6', 'europe-west8', 'europe-west9',
    # me-*
    'me-central1', 'me-central2', 'me-west1',
    # Americas
    'northamerica-northeast1', 'southamerica-east1',
    'us-central1', 'us-east1', 'us-east4', 'us-east5', 'us-south1', 'us-west1', 'us-west4',
]
"""Vertex AI regions accepted as a `location` (alongside `'global'`).
More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
"""
@@ -128,7 +128,7 @@ class _VertexAIAuth(httpx.Auth):
128
128
  self.credentials = None
129
129
 
130
130
  async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
131
- if self.credentials is None:
131
+ if self.credentials is None: # pragma: no branch
132
132
  self.credentials = await self._get_credentials()
133
133
  if self.credentials.token is None: # type: ignore[reportUnknownMemberType]
134
134
  await self._refresh_token()
@@ -157,9 +157,9 @@ class _VertexAIAuth(httpx.Auth):
157
157
  creds, creds_project_id = await _async_google_auth()
158
158
  creds_source = '`google.auth.default()`'
159
159
 
160
- if self.project_id is None:
160
+ if self.project_id is None: # pragma: no branch
161
161
  if creds_project_id is None:
162
- raise UserError(f'No project_id provided and none found in {creds_source}')
162
+ raise UserError(f'No project_id provided and none found in {creds_source}') # pragma: no cover
163
163
  self.project_id = creds_project_id
164
164
  return creds
165
165
 
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ from httpx import AsyncClient as AsyncHTTPClient
7
+ from openai import AsyncOpenAI
8
+
9
+ from pydantic_ai.exceptions import UserError
10
+ from pydantic_ai.models import cached_async_http_client
11
+ from pydantic_ai.providers import Provider
12
+
13
+ try:
14
+ from openai import AsyncOpenAI
15
+ except ImportError as _import_error: # pragma: no cover
16
+ raise ImportError(
17
+ 'Please install the `openai` package to use the OpenRouter provider, '
18
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
19
+ ) from _import_error
20
+
21
+
22
class OpenRouterProvider(Provider[AsyncOpenAI]):
    """Provider for OpenRouter API.

    Supplies an `AsyncOpenAI` client whose base URL points at OpenRouter
    (`https://openrouter.ai/api/v1`).
    """

    @property
    def name(self) -> str:
        return 'openrouter'

    @property
    def base_url(self) -> str:
        return 'https://openrouter.ai/api/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new OpenRouter provider.

        Args:
            api_key: The API key used to authenticate with OpenRouter. If not provided, the
                `OPENROUTER_API_KEY` environment variable is used.
            openai_client: An existing `AsyncOpenAI` client to use as-is. When given, `api_key`
                and `http_client` are not used to build a client.
            http_client: An existing `httpx.AsyncClient` for the new `AsyncOpenAI` client. If not
                provided, the client returned by `cached_async_http_client(provider='openrouter')` is used.

        Raises:
            UserError: If no API key is available (argument or environment) and no `openai_client` is given.
        """
        api_key = api_key or os.getenv('OPENROUTER_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `OPENROUTER_API_KEY` environment variable or pass it via `OpenRouterProvider(api_key=...)` '
                'to use the OpenRouter provider.'
            )

        if openai_client is not None:
            self._client = openai_client
            return

        if http_client is None:
            http_client = cached_async_http_client(provider='openrouter')
        self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)