pydantic-ai-slim 0.1.11__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of pydantic-ai-slim has been flagged as potentially problematic.
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_agent_graph.py +6 -8
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_parts_manager.py +3 -1
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/agent.py +21 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/messages.py +7 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/__init__.py +6 -7
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/_json_schema.py +8 -2
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/anthropic.py +23 -26
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/bedrock.py +36 -12
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/cohere.py +5 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/fallback.py +3 -4
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/function.py +9 -4
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/gemini.py +13 -5
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/groq.py +5 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/instrumented.py +8 -9
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/mistral.py +5 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/openai.py +9 -6
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/test.py +4 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/wrapper.py +1 -2
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/usage.py +5 -3
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/.gitignore +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/README.md +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_output.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pyproject.toml +0 -0
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.1.11
+Version: 0.2.0
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.1.11
+Requires-Dist: pydantic-graph==0.2.0
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.1.11; extra == 'evals'
+Requires-Dist: pydantic-evals==0.2.0; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_agent_graph.py

@@ -301,16 +301,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
                 ctx.state.message_history, model_settings, model_request_parameters
             ) as streamed_response:
                 self._did_stream = True
-                ctx.state.usage.incr(_usage.Usage(), requests=1)
+                ctx.state.usage.requests += 1
                 yield streamed_response
                 # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
                 # otherwise usage won't be properly counted:
                 async for _ in streamed_response:
                     pass
                 model_response = streamed_response.get()
-                request_usage = streamed_response.usage()

-            self._finish_handling(ctx, model_response, request_usage)
+            self._finish_handling(ctx, model_response)
             assert self._result is not None  # this should be set by the previous line

     async def _make_request(
@@ -321,12 +320,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
-        model_response, request_usage = await ctx.deps.model.request(
+        model_response = await ctx.deps.model.request(
            ctx.state.message_history, model_settings, model_request_parameters
        )
-        ctx.state.usage.incr(_usage.Usage(), requests=1)
+        ctx.state.usage.incr(_usage.Usage())

-        return self._finish_handling(ctx, model_response, request_usage)
+        return self._finish_handling(ctx, model_response)

     async def _prepare_request(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -348,10 +347,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         response: _messages.ModelResponse,
-        usage: _usage.Usage,
     ) -> CallToolsNode[DepsT, NodeRunEndT]:
         # Update usage
-        ctx.state.usage.incr(usage)
+        ctx.state.usage.incr(response.usage)
         if ctx.deps.usage_limits:
             ctx.deps.usage_limits.check_tokens(ctx.state.usage)

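Taken together, the graph changes above mean the per-request usage now arrives attached to the `ModelResponse` itself and is folded into the run total with `Usage.incr()`. A minimal sketch of that accounting flow, using only the `Usage` fields visible in this diff (the token counts are invented):

    from pydantic_ai.usage import Usage

    run_usage = Usage()
    run_usage.requests += 1  # streaming path: the request is counted as soon as the stream opens

    # usage reported by the model for this request, now carried on ModelResponse.usage
    response_usage = Usage(request_tokens=56, response_tokens=1, total_tokens=57)

    run_usage.incr(response_usage)  # what _finish_handling() now does with response.usage
    print(run_usage)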
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_parts_manager.py

@@ -14,7 +14,7 @@ event-emitting logic.
 from __future__ import annotations as _annotations

 from collections.abc import Hashable
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import Any, Union

 from pydantic_ai.exceptions import UnexpectedModelBehavior
@@ -198,6 +198,8 @@ class ModelResponsePartsManager:
                 return PartStartEvent(index=part_index, part=updated_part)
             else:
                 # We updated an existing part, so emit a PartDeltaEvent
+                if updated_part.tool_call_id and not delta.tool_call_id:
+                    delta = replace(delta, tool_call_id=updated_part.tool_call_id)
                 return PartDeltaEvent(index=part_index, delta=delta)

     def handle_tool_call_part(
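The two added lines backfill a missing `tool_call_id` on an incoming delta from the part it updates, using `dataclasses.replace`. A standalone sketch of the same idea; `ToolCallDelta` here is a stand-in written for this example, not the library's `ToolCallPartDelta`:

    from __future__ import annotations

    from dataclasses import dataclass, replace


    @dataclass(frozen=True)
    class ToolCallDelta:
        # stand-in for pydantic_ai's ToolCallPartDelta, just for this sketch
        args_delta: str
        tool_call_id: str | None = None


    existing_tool_call_id = 'call_abc123'  # id already stored on the part being updated
    delta = ToolCallDelta(args_delta='{"city": "Par')

    # mirror of the new logic: if the stored part has an id but the incoming delta does not,
    # copy it onto the delta before emitting the PartDeltaEvent
    if existing_tool_call_id and not delta.tool_call_id:
        delta = replace(delta, tool_call_id=existing_tool_call_id)

    assert delta.tool_call_id == 'call_abc123'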
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/agent.py

@@ -551,6 +551,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 CallToolsNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
+                        usage=Usage(
+                            requests=1,
+                            request_tokens=56,
+                            response_tokens=1,
+                            total_tokens=57,
+                            details=None,
+                        ),
                         model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
@@ -1715,6 +1722,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
                 CallToolsNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
+                        usage=Usage(
+                            requests=1,
+                            request_tokens=56,
+                            response_tokens=1,
+                            total_tokens=57,
+                            details=None,
+                        ),
                         model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
@@ -1853,6 +1867,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
                 CallToolsNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
+                        usage=Usage(
+                            requests=1,
+                            request_tokens=56,
+                            response_tokens=1,
+                            total_tokens=57,
+                            details=None,
+                        ),
                         model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/messages.py

@@ -14,6 +14,7 @@ from typing_extensions import TypeAlias

 from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
+from .usage import Usage

 AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
 ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
@@ -554,6 +555,12 @@ class ModelResponse:
     parts: list[ModelResponsePart]
     """The parts of the model message."""

+    usage: Usage = field(default_factory=Usage)
+    """Usage information for the request.
+
+    This has a default to make tests easier, and to support loading old messages where usage will be missing.
+    """
+
     model_name: str | None = None
     """The name of the model that generated the response."""

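With this field in place, a `ModelResponse` can be built and inspected with its usage attached. A small example based on the constructor arguments shown in the updated `agent.py` docstrings above (the values are the same illustrative ones used there):

    from pydantic_ai.messages import ModelResponse, TextPart
    from pydantic_ai.usage import Usage

    response = ModelResponse(
        parts=[TextPart(content='Paris')],
        usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57),
        model_name='gpt-4o',
    )
    print(response.usage.total_tokens)  # 57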
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/__init__.py

@@ -12,7 +12,6 @@ from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING

 import httpx
 from typing_extensions import Literal, TypeAliasType
@@ -21,12 +20,9 @@ from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
 from ..settings import ModelSettings
+from ..tools import ToolDefinition
 from ..usage import Usage

-if TYPE_CHECKING:
-    from ..tools import ToolDefinition
-
-
 KnownModelName = TypeAliasType(
     'KnownModelName',
     Literal[
@@ -278,7 +274,7 @@ class Model(ABC):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, Usage]:
+    ) -> ModelResponse:
         """Make a request to the model."""
         raise NotImplementedError()

@@ -365,7 +361,10 @@ class StreamedResponse(ABC):
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
         return ModelResponse(
-            parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
+            parts=self._parts_manager.get_parts(),
+            model_name=self.model_name,
+            timestamp=self.timestamp,
+            usage=self.usage(),
         )

     def usage(self) -> Usage:
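Since `Model.request()` no longer returns a `(response, usage)` tuple, callers that still want the pair can unpack it from the response. A hedged sketch of such a shim; `request_with_usage` is a hypothetical helper for illustration, not part of pydantic-ai:

    from __future__ import annotations

    from typing import Any

    from pydantic_ai.messages import ModelResponse
    from pydantic_ai.models import Model
    from pydantic_ai.usage import Usage


    async def request_with_usage(model: Model, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
        # Model.request() now returns a single ModelResponse; the per-request usage that
        # used to be the second element of the old tuple lives on response.usage instead.
        response = await model.request(*args, **kwargs)
        return response, response.usage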
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/_json_schema.py

@@ -25,7 +25,7 @@ class WalkJsonSchema(ABC):
         self.simplify_nullable_unions = simplify_nullable_unions

         self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
-        self.refs_stack = ()
+        self.refs_stack: list[str] = []
         self.recursive_refs = set[str]()

     @abstractmethod
@@ -62,13 +62,16 @@ class WalkJsonSchema(ABC):
         return handled

     def _handle(self, schema: JsonSchema) -> JsonSchema:
+        nested_refs = 0
         if self.prefer_inlined_defs:
             while ref := schema.get('$ref'):
                 key = re.sub(r'^#/\$defs/', '', ref)
                 if key in self.refs_stack:
                     self.recursive_refs.add(key)
                     break  # recursive ref can't be unpacked
-                self.refs_stack += (key,)
+                self.refs_stack.append(key)
+                nested_refs += 1
+
                 def_schema = self.defs.get(key)
                 if def_schema is None:  # pragma: no cover
                     raise UserError(f'Could not find $ref definition for {key}')
@@ -87,6 +90,9 @@ class WalkJsonSchema(ABC):
         # Apply the base transform
         schema = self.transform(schema)

+        if nested_refs > 0:
+            self.refs_stack = self.refs_stack[:-nested_refs]
+
         return schema

     def _handle_object(self, schema: JsonSchema) -> JsonSchema:
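The fix above pushes each inlined `$ref` onto `refs_stack` and pops it again once the definition has been handled, so a second, sibling use of the same definition is no longer misreported as recursive. A standalone sketch of that bookkeeping, using plain dicts and module-level state instead of the `WalkJsonSchema` class:

    import re

    defs = {'Foo': {'type': 'object'}}
    refs_stack: list[str] = []
    recursive_refs: set[str] = set()

    def handle(schema: dict) -> dict:
        nested_refs = 0
        while ref := schema.get('$ref'):
            key = re.sub(r'^#/\$defs/', '', ref)
            if key in refs_stack:
                recursive_refs.add(key)
                break  # a truly recursive ref can't be unpacked
            refs_stack.append(key)
            nested_refs += 1
            schema = defs[key]
        # ... transform `schema` here ...
        if nested_refs:
            del refs_stack[-nested_refs:]  # pop only what this call pushed
        return schema

    handle({'$ref': '#/$defs/Foo'})
    handle({'$ref': '#/$defs/Foo'})  # second sibling use of Foo: not flagged as recursive
    assert not recursive_refs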
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/anthropic.py

@@ -145,12 +145,14 @@ class AnthropicModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._messages_create(
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
-        return self._process_response(response), _map_usage(response)
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
@@ -260,7 +262,7 @@ class AnthropicModel(Model):
             )
         )

-        return ModelResponse(items, model_name=response.model)
+        return ModelResponse(items, usage=_map_usage(response), model_name=response.model)

     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -391,36 +393,31 @@
 def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
     if isinstance(message, AnthropicMessage):
         response_usage = message.usage
+    elif isinstance(message, RawMessageStartEvent):
+        response_usage = message.message.usage
+    elif isinstance(message, RawMessageDeltaEvent):
+        response_usage = message.usage
     else:
-        if isinstance(message, RawMessageStartEvent):
-            response_usage = message.message.usage
-        elif isinstance(message, RawMessageDeltaEvent):
-            response_usage = message.usage
-        else:
-            # No usage information provided in:
-            # - RawMessageStopEvent
-            # - RawContentBlockStartEvent
-            # - RawContentBlockDeltaEvent
-            # - RawContentBlockStopEvent
-            response_usage = None
-
-    if response_usage is None:
+        # No usage information provided in:
+        # - RawMessageStopEvent
+        # - RawContentBlockStartEvent
+        # - RawContentBlockDeltaEvent
+        # - RawContentBlockStopEvent
         return usage.Usage()

-    # Store all integer-typed usage values in the details
-    details: dict[str, int] = {}
-    for key, value in response_usage.model_dump().items():
-        if isinstance(value, int):
-            details[key] = value
+    # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
+    # `response_tokens`
+    details: dict[str, int] = {
+        key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
+    }

-    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `getattr`
+    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
     # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
     # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
     request_tokens = (
-        getattr(response_usage, 'input_tokens', 0)
-        + (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0)
-        + (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
+        details.get('input_tokens', 0)
+        + details.get('cache_creation_input_tokens', 0)
+        + details.get('cache_read_input_tokens', 0)
     )

     return usage.Usage(
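The reworked `_map_usage` keeps every integer field of Anthropic's usage object in `details` and derives `request_tokens` from fresh, cache-creation and cache-read input tokens, each counted once. A small worked example of that arithmetic with invented token counts:

    details = {
        'input_tokens': 10,
        'cache_creation_input_tokens': 200,
        'cache_read_input_tokens': 3000,
        'output_tokens': 50,
    }

    request_tokens = (
        details.get('input_tokens', 0)
        + details.get('cache_creation_input_tokens', 0)
        + details.get('cache_read_input_tokens', 0)
    )
    response_tokens = details.get('output_tokens', 0)
    print(request_tokens, response_tokens, request_tokens + response_tokens)  # 3210 50 3260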
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/bedrock.py

@@ -232,10 +232,12 @@ class BedrockConverseModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
-        return await self._process_response(response)
+        model_response = await self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
@@ -248,7 +250,7 @@ class BedrockConverseModel(Model):
         response = await self._messages_create(messages, True, settings, model_request_parameters)
         yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)

-    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, usage.Usage]:
+    async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
         items: list[ModelResponsePart] = []
         if message := response['output'].get('message'):
             for item in message['content']:
@@ -269,7 +271,7 @@ class BedrockConverseModel(Model):
             response_tokens=response['usage']['outputTokens'],
             total_tokens=response['usage']['totalTokens'],
         )
-        return ModelResponse(items, model_name=self.model_name)
+        return ModelResponse(items, usage=u, model_name=self.model_name)

     @overload
     async def _messages_create(
@@ -367,13 +369,16 @@ class BedrockConverseModel(Model):
     async def _map_messages(
         self, messages: list[ModelMessage]
     ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
-        """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
+        """Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.
+
+        Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
+        """
         system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         document_count: Iterator[int] = count(1)
-        for m in messages:
-            if isinstance(m, ModelRequest):
-                for part in m.parts:
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
                     if isinstance(part, SystemPromptPart):
                         system_prompt.append({'text': part.content})
                     elif isinstance(part, UserPromptPart):
@@ -414,9 +419,9 @@ class BedrockConverseModel(Model):
                         ],
                     }
                 )
-            elif isinstance(m, ModelResponse):
+            elif isinstance(message, ModelResponse):
                 content: list[ContentBlockOutputTypeDef] = []
-                for item in m.parts:
+                for item in message.parts:
                     if isinstance(item, TextPart):
                         content.append({'text': item.content})
                     else:
@@ -424,12 +429,31 @@ class BedrockConverseModel(Model):
                         content.append(self._map_tool_call(item))
                 bedrock_messages.append({'role': 'assistant', 'content': content})
             else:
-                assert_never(m)
+                assert_never(message)
+
+        # Merge together sequential user messages.
+        processed_messages: list[MessageUnionTypeDef] = []
+        last_message: dict[str, Any] | None = None
+        for current_message in bedrock_messages:
+            if (
+                last_message is not None
+                and current_message['role'] == last_message['role']
+                and current_message['role'] == 'user'
+            ):
+                # Add the new user content onto the existing user message.
+                last_content = list(last_message['content'])
+                last_content.extend(current_message['content'])
+                last_message['content'] = last_content
+                continue
+
+            # Add the entire message to the list of messages.
+            processed_messages.append(current_message)
+            last_message = cast(dict[str, Any], current_message)

         if instructions := self._get_instructions(messages):
             system_prompt.insert(0, {'text': instructions})

-        return system_prompt, bedrock_messages
+        return system_prompt, processed_messages

     @staticmethod
     async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
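The merging step added to `_map_messages` folds consecutive `user` messages (for example several tool results) into one, since Bedrock expects alternating roles. A standalone sketch with hand-written message dicts standing in for the real `MessageUnionTypeDef` values:

    from __future__ import annotations

    from typing import Any

    bedrock_messages: list[dict[str, Any]] = [
        {'role': 'user', 'content': [{'text': 'What is the capital of France?'}]},
        {'role': 'user', 'content': [{'toolResult': {'toolUseId': 't1', 'content': [{'text': 'Paris'}]}}]},
        {'role': 'assistant', 'content': [{'text': 'Paris'}]},
    ]

    processed: list[dict[str, Any]] = []
    last_message: dict[str, Any] | None = None
    for current in bedrock_messages:
        if last_message is not None and current['role'] == last_message['role'] == 'user':
            # fold this user message's content into the previous user message
            last_message['content'] = list(last_message['content']) + list(current['content'])
            continue
        processed.append(current)
        last_message = current

    assert len(processed) == 2 and len(processed[0]['content']) == 2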
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/cohere.py

@@ -133,10 +133,12 @@ class CohereModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
-        return self._process_response(response), _map_usage(response)
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @property
     def model_name(self) -> CohereModelName:
@@ -191,7 +193,7 @@ class CohereModel(Model):
                     tool_call_id=c.id or _generate_tool_call_id(),
                 )
             )
-        return ModelResponse(parts=parts, model_name=self._model_name)
+        return ModelResponse(parts=parts, usage=_map_usage(response), model_name=self._model_name)

     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
         """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/fallback.py

@@ -15,7 +15,6 @@ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 if TYPE_CHECKING:
     from ..messages import ModelMessage, ModelResponse
     from ..settings import ModelSettings
-    from ..usage import Usage


 @dataclass(init=False)
@@ -55,7 +54,7 @@ class FallbackModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, Usage]:
+    ) -> ModelResponse:
         """Try each model in sequence until one succeeds.

         In case of failure, raise a FallbackExceptionGroup with all exceptions.
@@ -65,7 +64,7 @@ class FallbackModel(Model):
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
             try:
-                response, usage = await model.request(messages, model_settings, customized_model_request_parameters)
+                response = await model.request(messages, model_settings, customized_model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -73,7 +72,7 @@ class FallbackModel(Model):
                     raise exc

             self._set_span_attributes(model)
-            return response, usage
+            return response

         raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/function.py

@@ -88,7 +88,7 @@ class FunctionModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
             model_request_parameters.allow_text_output,
@@ -105,8 +105,11 @@ class FunctionModel(Model):
         assert isinstance(response_, ModelResponse), response_
         response = response_
         response.model_name = self._model_name
-        # Add usage data
-        return response, _estimate_usage(chain(messages, [response]))
+        # Add usage data if not already present
+        if not response.usage.has_values():
+            response.usage = _estimate_usage(chain(messages, [response]))
+        response.usage.requests = 1
+        return response

     @asynccontextmanager
     async def request_stream(
@@ -273,7 +276,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     else:
         assert_never(message)
     return usage.Usage(
-        request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
+        request_tokens=request_tokens,
+        response_tokens=response_tokens,
+        total_tokens=request_tokens + response_tokens,
     )

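`_estimate_usage` now also reports response and total tokens, and `FunctionModel.request()` only falls back to the estimate when the function did not set usage itself (checked with the new `Usage.has_values()`). A small example of building such an estimate by hand; the token counts are invented:

    from pydantic_ai.usage import Usage

    request_tokens, response_tokens = 50, 20
    estimate = Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=request_tokens + response_tokens,
    )
    estimate.requests = 1  # FunctionModel.request() also records the single request it made
    print(estimate.has_values())  # True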
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/gemini.py

@@ -145,14 +145,14 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         async with self._make_request(
             messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
             data = await http_response.aread()
             response = _gemini_response_ta.validate_json(data)
-            return self._process_response(response), _metadata_as_usage(response)
+            return self._process_response(response)

     @asynccontextmanager
     async def request_stream(
@@ -269,7 +269,9 @@ class GeminiModel(Model):
         else:
             raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, response.get('model_version', self._model_name))
+        usage = _metadata_as_usage(response)
+        usage.requests = 1
+        return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)

     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -591,7 +593,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:


 def _process_response_from_parts(
-    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName,
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -603,7 +605,7 @@ def _process_response_from_parts(
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(parts=items, model_name=model_name)
+    return ModelResponse(parts=items, usage=usage, model_name=model_name)


 class _GeminiFunctionCall(TypedDict):
@@ -831,6 +833,12 @@ class _GeminiJsonSchema(WalkJsonSchema):
         schema.pop('exclusiveMaximum', None)
         schema.pop('exclusiveMinimum', None)

+        # Gemini only supports string enums, so we need to convert any enum values to strings.
+        # Pydantic will take care of transforming the transformed string values to the correct type.
+        if enum := schema.get('enum'):
+            schema['type'] = 'string'
+            schema['enum'] = [str(val) for val in enum]
+
         type_ = schema.get('type')
         if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
             # This gets hit when we have a discriminated union
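The added schema tweak stringifies enum members because Gemini only accepts string enums, relying on Pydantic to coerce the values back to the declared type on validation. The same transformation applied to a bare dict:

    schema = {'type': 'integer', 'enum': [1, 2, 3]}

    if enum := schema.get('enum'):
        schema['type'] = 'string'
        schema['enum'] = [str(val) for val in enum]

    assert schema == {'type': 'string', 'enum': ['1', '2', '3']}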
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/groq.py

@@ -130,12 +130,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
-        return self._process_response(response), _map_usage(response)
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
@@ -237,7 +239,7 @@ class GroqModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/instrumented.py

@@ -23,7 +23,6 @@ from ..messages import (
     ModelResponse,
 )
 from ..settings import ModelSettings
-from ..usage import Usage
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel

@@ -122,11 +121,11 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, Usage]:
+    ) -> ModelResponse:
         with self._instrument(messages, model_settings, model_request_parameters) as finish:
-            response, usage = await super().request(messages, model_settings, model_request_parameters)
-            finish(response, usage)
-            return response, usage
+            response = await super().request(messages, model_settings, model_request_parameters)
+            finish(response)
+            return response

     @asynccontextmanager
     async def request_stream(
@@ -144,7 +143,7 @@ class InstrumentedModel(WrapperModel):
                     yield response_stream
                 finally:
                     if response_stream:
-                        finish(response_stream.get(), response_stream.usage())
+                        finish(response_stream.get())

     @contextmanager
     def _instrument(
@@ -152,7 +151,7 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+    ) -> Iterator[Callable[[ModelResponse], None]]:
         operation = 'chat'
         span_name = f'{operation} {self.model_name}'
         # TODO Missing attributes:
@@ -177,7 +176,7 @@ class InstrumentedModel(WrapperModel):

         with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

-            def finish(response: ModelResponse, usage: Usage):
+            def finish(response: ModelResponse):
                 if not span.is_recording():
                     return

@@ -193,7 +192,7 @@ class InstrumentedModel(WrapperModel):
                         },
                     )
                 )
-                new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes()  # pyright: ignore[reportAssignmentType]
+                new_attributes: dict[str, AttributeValue] = response.usage.opentelemetry_attributes()  # pyright: ignore[reportAssignmentType]
                 attributes.update(getattr(span, 'attributes', {}))
                 request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
                 new_attributes['gen_ai.response.model'] = response.model_name or request_model
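Instrumentation now reads usage straight off the `ModelResponse` and exports it via `Usage.opentelemetry_attributes()`. A minimal example of producing that attribute dict (the exact keys are defined by the library; run it to see them):

    from pydantic_ai.usage import Usage

    u = Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57)
    print(u.opentelemetry_attributes())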
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/mistral.py

@@ -147,13 +147,15 @@ class MistralModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         """Make a non-streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
         response = await self._completions_create(
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
-        return self._process_response(response), _map_usage(response)
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
@@ -323,7 +325,7 @@ class MistralModel(Model):
             tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)

-        return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
+        return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(
         self,
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/openai.py

@@ -192,12 +192,14 @@ class OpenAIModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
         )
-        return self._process_response(response), _map_usage(response)
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
@@ -304,7 +306,7 @@ class OpenAIModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -522,12 +524,12 @@ class OpenAIResponsesModel(Model):
         messages: list[ModelRequest | ModelResponse],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._responses_create(
             messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
-        return self._process_response(response), _map_usage(response)
+        return self._process_response(response)

     @asynccontextmanager
     async def request_stream(
@@ -554,7 +556,7 @@ class OpenAIResponsesModel(Model):
         for item in response.output:
             if item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
-        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(
         self, response: AsyncStream[responses.ResponseStreamEvent]
@@ -935,6 +937,7 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage:
     if response_usage.prompt_tokens_details is not None:
         details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
     return usage.Usage(
+        requests=1,
         request_tokens=response_usage.prompt_tokens,
         response_tokens=response_usage.completion_tokens,
         total_tokens=response_usage.total_tokens,
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/test.py

@@ -86,11 +86,12 @@ class TestModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
+    ) -> ModelResponse:
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)
-        usage = _estimate_usage([*messages, model_response])
-        return model_response, usage
+        model_response.usage = _estimate_usage([*messages, model_response])
+        model_response.usage.requests = 1
+        return model_response

     @asynccontextmanager
     async def request_stream(
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/wrapper.py

@@ -7,7 +7,6 @@ from typing import Any

 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
-from ..usage import Usage
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model


@@ -24,7 +23,7 @@ class WrapperModel(Model):
     def __init__(self, wrapped: Model | KnownModelName):
         self.wrapped = infer_model(wrapped)

-    async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+    async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
         return await self.wrapped.request(*args, **kwargs)

     @asynccontextmanager
{pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/usage.py

@@ -28,14 +28,12 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def incr(self, incr_usage: Usage, requests: int = 0) -> None:
+    def incr(self, incr_usage: Usage) -> None:
         """Increment the usage in place.

         Args:
             incr_usage: The usage to increment by.
-            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
-        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
             other_value = getattr(incr_usage, f)
@@ -66,6 +64,10 @@ class Usage:
                 result[f'gen_ai.usage.details.{key}'] = value
         return {k: v for k, v in result.items() if v}

+    def has_values(self) -> bool:
+        """Whether any values are set and non-zero."""
+        return bool(self.requests or self.request_tokens or self.response_tokens or self.details)
+

 @dataclass
 class UsageLimits:

All remaining files in the listing above are unchanged between 0.1.11 and 0.2.0.