pydantic-ai-slim 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.


pydantic_ai/_agent_graph.py CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
+import json
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
-import logfire_api
+from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -42,17 +43,6 @@ __all__ = (
     'capture_run_messages',
 )
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
 S = TypeVar('S')
@@ -105,7 +95,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
 
-    run_span: logfire_api.LogfireSpan
+    run_span: Span
+    tracer: Tracer
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -330,7 +321,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.run_step += 1
 
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
+        with ctx.deps.tracer.start_as_current_span(
+            'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
+        ):
             model_request_parameters = await _prepare_request_parameters(ctx)
         return model_settings, model_request_parameters
 
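The Logfire-specific span helper is replaced by the stock OpenTelemetry tracer API throughout this module. A minimal sketch of the pattern, assuming an OpenTelemetry SDK is configured; the scope name below is illustrative, not the one pydantic-ai registers:

from opentelemetry import trace

tracer = trace.get_tracer('example-scope')  # illustrative scope name
with tracer.start_as_current_span('preparing model request params', attributes={'run_step': 1}):
    ...  # the span is current (active) for the duration of this block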
@@ -380,26 +373,12 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
-        with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-            stream = self._run_stream(ctx)
-            yield stream
+        stream = self._run_stream(ctx)
+        yield stream
 
-            # Run the stream to completion if it was not finished:
-            async for _event in stream:
-                pass
-
-            # Set the next node based on the final state of the stream
-            next_node = self._next_node
-            if isinstance(next_node, End):
-                handle_span.set_attribute('result', next_node.data)
-                handle_span.message = 'handle model response -> final result'
-            elif tool_responses := self._tool_responses:
-                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
-                # I'm thinking it might be better to just create a span for the handling of each tool
-                # than to set an attribute here.
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
+        # Run the stream to completion if it was not finished:
+        async for _event in stream:
+            pass
 
     async def _run_stream(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -494,10 +473,29 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         if tool_responses:
             messages.append(_messages.ModelRequest(parts=tool_responses))
 
-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute(
-            'all_messages_events',
-            [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)],
+        run_span.set_attributes(
+            {
+                **usage.opentelemetry_attributes(),
+                'all_messages_events': json.dumps(
+                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
+                ),
+                'final_result': final_result.data
+                if isinstance(final_result.data, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
+            }
+        )
+        run_span.set_attributes(
+            {
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'all_messages_events': {'type': 'array'},
+                            'final_result': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
         )
 
         # End the run with self.data
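Note the serialization change here: OpenTelemetry span attributes only accept primitives (str, bool, int, float) and sequences of them, so values previously handed to Logfire as structured objects are now JSON-encoded strings, and the Logfire-specific `logfire.json_schema` attribute tells the Logfire UI which attributes to decode back for display. A small sketch of the same pattern, assuming an OpenTelemetry SDK is configured:

import json

from opentelemetry import trace

tracer = trace.get_tracer('example-scope')  # illustrative scope name
with tracer.start_as_current_span('demo') as span:
    # Structured values must be flattened to primitives, hence the JSON string:
    span.set_attributes({'final_result': json.dumps({'answer': 42})})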
@@ -619,7 +617,10 @@ async def process_function_tools(
 
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
-    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+    tool_names = [call.tool_name for _, call in calls_to_run]
+    with ctx.deps.tracer.start_as_current_span(
+        'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
+    ):
         # TODO: Should we wrap each individual tool call in a dedicated span?
         tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
         pending = tasks
pydantic_ai/_pydantic.py CHANGED
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 from __future__ import annotations as _annotations
 
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
+from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
 
 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
+from typing_extensions import get_origin
 
 from ._griffe import doc_descriptions
 from ._utils import check_object_json_schema, is_model_like
@@ -223,8 +224,7 @@ def _build_schema(
 
 
 def _is_call_ctx(annotation: Any) -> bool:
+    """Return whether the annotation is the `RunContext` class, parameterized or not."""
     from .tools import RunContext
 
-    return annotation is RunContext or (
-        _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
-    )
+    return annotation is RunContext or get_origin(annotation) is RunContext
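This simplification works because `typing_extensions.get_origin` already returns the unsubscripted class for a parameterized generic (and `None` for a bare class), making the internal `_typing_extra.is_generic_alias` guard redundant. A quick illustration with a stand-in generic class:

from typing import Generic, TypeVar

from typing_extensions import get_origin

T = TypeVar('T')

class RunContext(Generic[T]):  # stand-in for pydantic_ai.tools.RunContext
    pass

assert get_origin(RunContext[int]) is RunContext  # parameterized: origin is the class itself
assert get_origin(RunContext) is None             # bare class: hence the separate `is RunContext` check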
pydantic_ai/_result.py CHANGED
@@ -1,14 +1,14 @@
 from __future__ import annotations as _annotations
 
 import inspect
-import sys
-import types
 from collections.abc import Awaitable, Iterable, Iterator
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
+from typing import Any, Callable, Generic, Literal, Union, cast
 
 from pydantic import TypeAdapter, ValidationError
-from typing_extensions import TypeAliasType, TypedDict, TypeVar
+from typing_extensions import TypedDict, TypeVar, get_args, get_origin
+from typing_inspection import typing_objects
+from typing_inspection.introspection import is_union_origin
 
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
 
 
 def get_union_args(tp: Any) -> tuple[Any, ...]:
-    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty union."""
-    if isinstance(tp, TypeAliasType):
+    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
+    if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
 
     origin = get_origin(tp)
-    if origin_is_union(origin):
+    if is_union_origin(origin):
         return get_args(tp)
     else:
         return ()
-
-
-if sys.version_info < (3, 10):
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union
-
-else:
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union or tp is types.UnionType
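The `typing_inspection` helpers replace the hand-rolled, version-gated `origin_is_union`: `is_union_origin` recognizes both `typing.Union` and the `types.UnionType` produced by PEP 604 `X | Y` syntax on Python 3.10+. The behavior the rewrite preserves, sketched as assertions:

from typing import Optional, Union

assert get_union_args(Union[int, str]) == (int, str)
assert get_union_args(Optional[int]) == (int, type(None))
assert get_union_args(int) == ()  # not a union: empty tuple
# On Python 3.10+, PEP 604 unions work the same way:
# assert get_union_args(int | str) == (int, str)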
pydantic_ai/agent.py CHANGED
@@ -8,7 +8,7 @@ from copy import deepcopy
 from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
-import logfire_api
+from opentelemetry.trace import NoOpTracer, use_span
 from typing_extensions import TypeGuard, TypeVar, deprecated
 
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -58,17 +58,6 @@ __all__ = (
     'UserPromptNode',
 )
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
 S = TypeVar('S')
@@ -123,6 +112,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     The type of the result data, used to validate the result data, defaults to `str`.
     """
 
+    instrument: bool
+    """Automatically instrument with OpenTelemetry. Will use Logfire if it's configured."""
+
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _result_tool_name: str = dataclasses.field(repr=False)
     _result_tool_description: str | None = dataclasses.field(repr=False)
@@ -155,6 +147,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
+        instrument: bool = False,
     ):
         """Create an agent.
 
@@ -184,6 +177,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 [override the model][pydantic_ai.Agent.override] for testing.
             end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
+            instrument: Automatically instrument with OpenTelemetry. Will use Logfire if it's configured.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -194,6 +188,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         self.name = name
         self.model_settings = model_settings
         self.result_type = result_type
+        self.instrument = instrument
 
         self._deps_type = deps_type
 
@@ -396,6 +391,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
         model_used = self._get_model(model)
+        del model
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -425,14 +421,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         model_settings = merge_model_settings(self.model_settings, model_settings)
         usage_limits = usage_limits or _usage.UsageLimits()
 
-        # Build the deps object for the graph
-        run_span = _logfire.span(
-            '{agent_name} run {prompt=}',
-            prompt=user_prompt,
-            agent=self,
-            model_name=model_used.model_name if model_used else 'no-model',
-            agent_name=self.name or 'agent',
+        if isinstance(model_used, InstrumentedModel):
+            tracer = model_used.tracer
+        else:
+            tracer = NoOpTracer()
+        agent_name = self.name or 'agent'
+        run_span = tracer.start_span(
+            'agent run',
+            attributes={
+                'model_name': model_used.model_name if model_used else 'no-model',
+                'agent_name': agent_name,
+                'logfire.msg': f'{agent_name} run',
+            },
         )
+
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -447,6 +449,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             result_validators=result_validators,
             function_tools=self._function_tools,
             run_span=run_span,
+            tracer=tracer,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
@@ -460,7 +463,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             state=state,
             deps=graph_deps,
             infer_name=False,
-            span=run_span,
+            span=use_span(run_span, end_on_exit=True),
         ) as graph_run:
             yield AgentRun(graph_run)
 
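Unlike `start_as_current_span`, `tracer.start_span` only creates the span; it does not make it current. The run span is activated later via `use_span(run_span, end_on_exit=True)`, so it stays open exactly as long as the graph run context. The split, sketched:

from opentelemetry import trace
from opentelemetry.trace import use_span

tracer = trace.get_tracer('example-scope')  # illustrative scope name
span = tracer.start_span('agent run')  # created, but not yet the current span
with use_span(span, end_on_exit=True):  # now current; ended automatically on exit
    ...  # the whole agent run happens here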
@@ -1116,7 +1119,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-        if not isinstance(model_, InstrumentedModel):
+        if self.instrument and not isinstance(model_, InstrumentedModel):
             model_ = InstrumentedModel(model_)
 
         return model_
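Instrumentation is therefore opt-in from this release: the resolved model is wrapped in `InstrumentedModel` only when the agent was constructed with `instrument=True` (or the model is already instrumented). A usage sketch, assuming the OpenAI optional dependency is installed:

from pydantic_ai import Agent

# Opt in to OpenTelemetry instrumentation for this agent's runs; with the
# default instrument=False, a NoOpTracer is used and no spans are emitted.
agent = Agent('openai:gpt-4o', instrument=True)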
pydantic_ai/models/__init__.py CHANGED
@@ -28,9 +28,11 @@ if TYPE_CHECKING:
 
 
 KnownModelName = Literal[
+    'anthropic:claude-3-7-sonnet-latest',
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -47,6 +49,8 @@ KnownModelName = Literal[
     'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
+    'deepseek:deepseek-chat',
+    'deepseek:deepseek-reasoner',
     'google-gla:gemini-1.0-pro',
     'google-gla:gemini-1.5-flash',
     'google-gla:gemini-1.5-flash-8b',
@@ -56,6 +60,7 @@ KnownModelName = Literal[
     'google-gla:gemini-exp-1206',
     'google-gla:gemini-2.0-flash',
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
+    'google-gla:gemini-2.0-pro-exp-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -65,6 +70,7 @@ KnownModelName = Literal[
     'google-vertex:gemini-exp-1206',
     'google-vertex:gemini-2.0-flash',
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
+    'google-vertex:gemini-2.0-pro-exp-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
@@ -316,54 +322,52 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel
 
         return TestModel()
-    elif model.startswith('cohere:'):
-        from .cohere import CohereModel
 
-        return CohereModel(model[7:])
-    elif model.startswith('openai:'):
-        from .openai import OpenAIModel
+    try:
+        provider, model_name = model.split(':')
+    except ValueError:
+        model_name = model
+        # TODO(Marcelo): We should deprecate this way.
+        if model_name.startswith(('gpt', 'o1', 'o3')):
+            provider = 'openai'
+        elif model_name.startswith('claude'):
+            provider = 'anthropic'
+        elif model_name.startswith('gemini'):
+            provider = 'google-gla'
+        else:
+            raise UserError(f'Unknown model: {model}')
+
+    if provider == 'vertexai':
+        provider = 'google-vertex'
+
+    if provider == 'cohere':
+        from .cohere import CohereModel
 
-        return OpenAIModel(model[7:])
-    elif model.startswith(('gpt', 'o1', 'o3')):
+        # TODO(Marcelo): Missing provider API.
+        return CohereModel(model_name)
+    elif provider in ('deepseek', 'openai'):
         from .openai import OpenAIModel
 
-        return OpenAIModel(model)
-    elif model.startswith('google-gla'):
-        from .gemini import GeminiModel
-
-        return GeminiModel(model[11:])
-    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
-    elif model.startswith('gemini'):
+        return OpenAIModel(model_name, provider=provider)
+    elif provider in ('google-gla', 'google-vertex'):
         from .gemini import GeminiModel
 
-        # noinspection PyTypeChecker
-        return GeminiModel(model)
-    elif model.startswith('groq:'):
+        return GeminiModel(model_name, provider=provider)
+    elif provider == 'groq':
         from .groq import GroqModel
 
-        return GroqModel(model[5:])
-    elif model.startswith('google-vertex'):
-        from .vertexai import VertexAIModel
-
-        return VertexAIModel(model[14:])
-    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
-    elif model.startswith('vertexai:'):
-        from .vertexai import VertexAIModel
-
-        return VertexAIModel(model[9:])
-    elif model.startswith('mistral:'):
+        # TODO(Marcelo): Missing provider API.
+        return GroqModel(model_name)
+    elif provider == 'mistral':
         from .mistral import MistralModel
 
-        return MistralModel(model[8:])
-    elif model.startswith('anthropic'):
-        from .anthropic import AnthropicModel
-
-        return AnthropicModel(model[10:])
-    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
-    elif model.startswith('claude'):
+        # TODO(Marcelo): Missing provider API.
+        return MistralModel(model_name)
+    elif provider == 'anthropic':
         from .anthropic import AnthropicModel
 
-        return AnthropicModel(model)
+        # TODO(Marcelo): Missing provider API.
+        return AnthropicModel(model_name)
     else:
         raise UserError(f'Unknown model: {model}')
 
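The rewrite normalizes every model string to a `provider:model_name` pair before dispatching, which is what lets the new `deepseek:` prefix reuse `OpenAIModel`. Roughly, assuming the relevant optional dependencies are installed and configured:

from pydantic_ai.models import infer_model

m1 = infer_model('deepseek:deepseek-chat')          # OpenAIModel with provider='deepseek'
m2 = infer_model('google-vertex:gemini-2.0-flash')  # GeminiModel with provider='google-vertex'
m3 = infer_model('vertexai:gemini-1.5-flash')       # legacy alias, normalized to google-vertex
m4 = infer_model('claude-3-5-sonnet-latest')        # bare name: provider inferred as anthropic (deprecated path)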
pydantic_ai/models/anthropic.py CHANGED
@@ -42,6 +42,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ContentBlock,
         ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
@@ -69,6 +70,7 @@ except ImportError as _import_error:
     ) from _import_error
 
 LatestAnthropicModelNames = Literal[
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -423,7 +425,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     _timestamp: datetime
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        current_block: TextBlock | ToolUseBlock | None = None
+        current_block: ContentBlock | None = None
         current_json: str = ''
 
         async for event in self._response:
pydantic_ai/models/gemini.py CHANGED
@@ -8,12 +8,14 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union, cast
+from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
 from uuid import uuid4
 
 import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
+
+from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
@@ -53,6 +55,7 @@ LatestGeminiModelNames = Literal[
     'gemini-exp-1206',
     'gemini-2.0-flash',
     'gemini-2.0-flash-lite-preview-02-05',
+    'gemini-2.0-pro-exp-02-05',
 ]
 """Latest Gemini models."""
 
@@ -81,17 +84,39 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    http_client: AsyncHTTPClient = field(repr=False)
+    client: AsyncHTTPClient = field(repr=False)
 
     _model_name: GeminiModelName = field(repr=False)
+    _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str | None = field(default='google-gla', repr=False)
 
+    @overload
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
+    ) -> None: ...
+
+    @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
+    @overload
     def __init__(
         self,
         model_name: GeminiModelName,
         *,
+        provider: None = None,
+        api_key: str | None = None,
+        http_client: AsyncHTTPClient | None = None,
+        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
+    ) -> None: ...
+
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
         api_key: str | None = None,
         http_client: AsyncHTTPClient | None = None,
         url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
@@ -100,6 +125,7 @@ class GeminiModel(Model):
 
         Args:
             model_name: The name of the model to use.
+            provider: The provider to use for the model.
             api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                 will be used if available.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
@@ -108,14 +134,24 @@ class GeminiModel(Model):
                 `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
         self._model_name = model_name
-        if api_key is None:
-            if env_api_key := os.getenv('GEMINI_API_KEY'):
-                api_key = env_api_key
+        self._provider = provider
+
+        if provider is not None:
+            if isinstance(provider, str):
+                self._system = provider
+                self.client = infer_provider(provider).client
             else:
-                raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-        self.http_client = http_client or cached_async_http_client()
-        self._auth = ApiKeyAuth(api_key)
-        self._url = url_template.format(model=model_name)
+                self._system = provider.name
+                self.client = provider.client
+        else:
+            if api_key is None:
+                if env_api_key := os.getenv('GEMINI_API_KEY'):
+                    api_key = env_api_key
+                else:
+                    raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+            self.client = http_client or cached_async_http_client()
+            self._auth = ApiKeyAuth(api_key)
+            self._url = url_template.format(model=model_name)
 
     @property
     def auth(self) -> AuthProtocol:
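Putting the new constructor to use: passing `provider` delegates HTTP client and credential setup to `infer_provider`, while the old keyword arguments still work but now resolve to the `@deprecated` overload. A sketch, assuming the required credentials are available in the environment:

from pydantic_ai.models.gemini import GeminiModel

# New style: the provider supplies the HTTP client, base URL, and auth.
model = GeminiModel('gemini-2.0-flash', provider='google-gla')

# Old style still works, but its overload is marked deprecated:
legacy = GeminiModel('gemini-2.0-flash', api_key='<your-api-key>')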
@@ -216,17 +252,19 @@ class GeminiModel(Model):
         if generation_config:
             request_data['generation_config'] = generation_config
 
-        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
-
         headers = {
             'Content-Type': 'application/json',
             'User-Agent': get_user_agent(),
-            **await self.auth.headers(),
         }
+        if self._provider is None:  # pragma: no cover
+            url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
+            headers.update(await self.auth.headers())
+        else:
+            url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream(
+        async with self.client.stream(
             'POST',
             url,
             content=request_json,