pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- pydantic_ai/__init__.py +6 -0
- pydantic_ai/_agent_graph.py +67 -20
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_output.py +20 -12
- pydantic_ai/_run_context.py +6 -2
- pydantic_ai/_utils.py +26 -8
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -25
- pydantic_ai/agent/abstract.py +146 -9
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/__init__.py +11 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/exceptions.py +6 -1
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/messages.py +46 -8
- pydantic_ai/models/__init__.py +87 -38
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +26 -23
- pydantic_ai/models/groq.py +13 -5
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +251 -52
- pydantic_ai/models/outlines.py +563 -0
- pydantic_ai/models/test.py +6 -3
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/__init__.py +25 -12
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +91 -24
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/providers/outlines.py +40 -0
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/result.py +173 -8
- pydantic_ai/run.py +40 -24
- pydantic_ai/settings.py +8 -0
- pydantic_ai/tools.py +10 -6
- pydantic_ai/toolsets/fastmcp.py +215 -0
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from .agent import (
 from .builtin_tools import (
     CodeExecutionTool,
     ImageGenerationTool,
+    MCPServerTool,
     MemoryTool,
     UrlContextTool,
     WebSearchTool,
@@ -22,6 +23,7 @@ from .exceptions import (
     ApprovalRequired,
     CallDeferred,
     FallbackExceptionGroup,
+    IncompleteToolCall,
     ModelHTTPError,
     ModelRetry,
     UnexpectedModelBehavior,
@@ -63,6 +65,7 @@ from .messages import (
     ModelResponseStreamEvent,
     MultiModalContent,
     PartDeltaEvent,
+    PartEndEvent,
     PartStartEvent,
     RetryPromptPart,
     SystemPromptPart,
@@ -124,6 +127,7 @@ __all__ = (
     'ModelRetry',
     'ModelHTTPError',
     'FallbackExceptionGroup',
+    'IncompleteToolCall',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
@@ -161,6 +165,7 @@ __all__ = (
     'ModelResponseStreamEvent',
     'MultiModalContent',
     'PartDeltaEvent',
+    'PartEndEvent',
     'PartStartEvent',
     'RetryPromptPart',
     'SystemPromptPart',
@@ -211,6 +216,7 @@ __all__ = (
     'CodeExecutionTool',
     'ImageGenerationTool',
     'MemoryTool',
+    'MCPServerTool',
     # output
     'ToolOutput',
     'NativeOutput',
pydantic_ai/_agent_graph.py
CHANGED
@@ -20,7 +20,8 @@ from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
 from pydantic_ai._tool_manager import ToolManager
 from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
 from pydantic_ai.builtin_tools import AbstractBuiltinTool
-from pydantic_graph import BaseNode,
+from pydantic_graph import BaseNode, GraphRunContext
+from pydantic_graph.beta import Graph, GraphBuilder
 from pydantic_graph.nodes import End, NodeRunEndT

 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
@@ -92,9 +93,28 @@ class GraphAgentState:
     retries: int = 0
     run_step: int = 0

-    def increment_retries(
+    def increment_retries(
+        self,
+        max_result_retries: int,
+        error: BaseException | None = None,
+        model_settings: ModelSettings | None = None,
+    ) -> None:
         self.retries += 1
         if self.retries > max_result_retries:
+            if (
+                self.message_history
+                and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
+                and model_response.finish_reason == 'length'
+                and model_response.parts
+                and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
+            ):
+                try:
+                    tool_call.args_as_dict()
+                except Exception:
+                    max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None
+                    raise exceptions.IncompleteToolCall(
+                        f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
+                    )
             message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
             if error:
                 if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
@@ -247,6 +267,9 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

         next_message.instructions = await ctx.deps.get_instructions(run_context)

+        if not messages and not next_message.parts and not next_message.instructions:
+            raise exceptions.UserError('No message history, user prompt, or instructions provided')
+
         return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)

     async def _handle_deferred_tool_results(  # noqa: C901
@@ -568,8 +591,12 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             # resubmit the most recent request that resulted in an empty response,
             # as the empty response and request will not create any items in the API payload,
             # in the hope the model will return a non-empty response this time.
-            ctx.state.increment_retries(ctx.deps.max_result_retries)
-
+            ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
+            run_context = build_run_context(ctx)
+            instructions = await ctx.deps.get_instructions(run_context)
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                _messages.ModelRequest(parts=[], instructions=instructions)
+            )
             return

         text = ''
@@ -630,8 +657,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
                     )
                     raise ToolRetryError(m)
             except ToolRetryError as e:
-                ctx.state.increment_retries(
-
+                ctx.state.increment_retries(
+                    ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+                )
+                run_context = build_run_context(ctx)
+                instructions = await ctx.deps.get_instructions(run_context)
+                self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                    _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions)
+                )

         self._events_iterator = _run_stream()

@@ -788,10 +821,14 @@ async def process_tool_calls(  # noqa: C901
         try:
             result_data = await tool_manager.handle_call(call)
         except exceptions.UnexpectedModelBehavior as e:
-            ctx.state.increment_retries(
+            ctx.state.increment_retries(
+                ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+            )
             raise e  # pragma: lax no cover
         except ToolRetryError as e:
-            ctx.state.increment_retries(
+            ctx.state.increment_retries(
+                ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+            )
             yield _messages.FunctionToolCallEvent(call)
             output_parts.append(e.tool_retry)
             yield _messages.FunctionToolResultEvent(e.tool_retry)
@@ -820,7 +857,7 @@ async def process_tool_calls(  # noqa: C901

     # Then, we handle unknown tool calls
     if tool_calls_by_kind['unknown']:
-        ctx.state.increment_retries(ctx.deps.max_result_retries)
+        ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
         calls_to_run.extend(tool_calls_by_kind['unknown'])

     calls_to_run_results: dict[str, DeferredToolResult] = {}
@@ -1129,22 +1166,32 @@ def build_agent_graph(
     name: str | None,
     deps_type: type[DepsT],
     output_type: OutputSpec[OutputT],
-) -> Graph[
+) -> Graph[
+    GraphAgentState,
+    GraphAgentDeps[DepsT, OutputT],
+    UserPromptNode[DepsT, OutputT],
+    result.FinalResult[OutputT],
+]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
-
-        UserPromptNode[DepsT],
-        ModelRequestNode[DepsT],
-        CallToolsNode[DepsT],
-        SetFinalResult[DepsT],
-    )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
-        nodes=nodes,
+    g = GraphBuilder(
         name=name or 'Agent',
         state_type=GraphAgentState,
-
+        deps_type=GraphAgentDeps[DepsT, OutputT],
+        input_type=UserPromptNode[DepsT, OutputT],
+        output_type=result.FinalResult[OutputT],
         auto_instrument=False,
     )
-
+
+    g.add(
+        g.edge_from(g.start_node).to(UserPromptNode[DepsT, OutputT]),
+        g.node(UserPromptNode[DepsT, OutputT]),
+        g.node(ModelRequestNode[DepsT, OutputT]),
+        g.node(CallToolsNode[DepsT, OutputT]),
+        g.node(
+            SetFinalResult[DepsT, OutputT],
+        ),
+    )
+    return g.build(validate_graph_structure=False)


 async def _process_message_history(
pydantic_ai/_cli.py
CHANGED
@@ -103,7 +103,7 @@ def cli_exit(prog_name: str = 'pai'):  # pragma: no cover


 def cli(  # noqa: C901
-    args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-
+    args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-5'
 ) -> int:
     """Run the CLI and return the exit code for the process."""
     parser = argparse.ArgumentParser(
@@ -124,7 +124,7 @@ Special prompts:
         '-m',
         '--model',
         nargs='?',
-        help=f'Model to use, in format "<provider>:<model>" e.g. "openai:gpt-
+        help=f'Model to use, in format "<provider>:<model>" e.g. "openai:gpt-5" or "anthropic:claude-sonnet-4-5". Defaults to "{default_model}".',
     )
     # we don't want to autocomplete or list models that don't include the provider,
     # e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
pydantic_ai/_output.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations

 import inspect
 import json
+import re
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable, Sequence
 from dataclasses import dataclass, field
@@ -70,6 +71,7 @@ Usage `OutputValidatorFunc[AgentDepsT, T]`.

 DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
 DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
+OUTPUT_TOOL_NAME_SANITIZER = re.compile(r'[^a-zA-Z0-9-_]')


 async def execute_traced_output_function(
@@ -554,6 +556,20 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
     def mode(self) -> OutputMode:
         return 'prompted'

+    @classmethod
+    def build_instructions(cls, template: str, object_def: OutputObjectDefinition) -> str:
+        """Build instructions from a template and an object definition."""
+        schema = object_def.json_schema.copy()
+        if object_def.name:
+            schema['title'] = object_def.name
+        if object_def.description:
+            schema['description'] = object_def.description
+
+        if '{schema}' not in template:
+            template = '\n\n'.join([template, '{schema}'])
+
+        return template.format(schema=json.dumps(schema))
+
     def raise_if_unsupported(self, profile: ModelProfile) -> None:
         """Raise an error if the mode is not supported by this model."""
         super().raise_if_unsupported(profile)
@@ -561,18 +577,8 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
     def instructions(self, default_template: str) -> str:
         """Get instructions to tell model to output JSON matching the schema."""
         template = self.template or default_template
-
-        if '{schema}' not in template:
-            template = '\n\n'.join([template, '{schema}'])
-
         object_def = self.object_def
-
-        if object_def.name:
-            schema['title'] = object_def.name
-        if object_def.description:
-            schema['description'] = object_def.description
-
-        return template.format(schema=json.dumps(schema))
+        return self.build_instructions(template, object_def)


 @dataclass(init=False)
@@ -997,7 +1003,9 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
         if name is None:
             name = default_name
         if multiple:
-
+            # strip unsupported characters like "[" and "]" from generic class names
+            safe_name = OUTPUT_TOOL_NAME_SANITIZER.sub('', object_def.name or '')
+            name += f'_{safe_name}'

         i = 1
         original_name = name
pydantic_ai/_run_context.py
CHANGED
@@ -16,15 +16,19 @@ if TYPE_CHECKING:
     from .models import Model
     from .result import RunUsage

+# TODO (v2): Change the default for all typevars like this from `None` to `object`
 AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
 """Type variable for agent dependencies."""

+RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True)
+"""Type variable for the agent dependencies in `RunContext`."""
+

 @dataclasses.dataclass(repr=False, kw_only=True)
-class RunContext(Generic[
+class RunContext(Generic[RunContextAgentDepsT]):
     """Information about the current call."""

-    deps:
+    deps: RunContextAgentDepsT
     """Dependencies for the agent."""
     model: Model
     """The model used in this run."""
pydantic_ai/_utils.py
CHANGED
@@ -147,7 +147,7 @@ async def group_by_temporal(
         aiterable: The async iterable to group.
         soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
             a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
-            as soon as `aiter
+            as soon as `anext(aiter)` returns. If `None`, no grouping/debouncing is performed

     Returns:
         A context manager usable as an async iterable of lists of items produced by the input async iterable.
@@ -171,7 +171,7 @@ async def group_by_temporal(
        buffer: list[T] = []
        group_start_time = time.monotonic()

-       aiterator = aiterable
+       aiterator = aiter(aiterable)
        while True:
            if group_start_time is None:
                # group hasn't started, we just wait for the maximum interval
@@ -182,9 +182,9 @@ async def group_by_temporal(

                # if there's no current task, we get the next one
                if task is None:
-                   # aiter
+                   # anext(aiter) returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
                    # so far, this doesn't seem to be a problem
-                   task = asyncio.create_task(aiterator
+                   task = asyncio.create_task(anext(aiterator))  # pyright: ignore[reportArgumentType]

                # we use asyncio.wait to avoid cancelling the coroutine if it's not done
                done, _ = await asyncio.wait((task,), timeout=wait_time)
@@ -234,6 +234,15 @@ def sync_anext(iterator: Iterator[T]) -> T:
        raise StopAsyncIteration() from e


+def sync_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
+    loop = get_event_loop()
+    while True:
+        try:
+            yield loop.run_until_complete(anext(async_iter))
+        except StopAsyncIteration:
+            break
+
+
 def now_utc() -> datetime:
     return datetime.now(tz=timezone.utc)

@@ -284,10 +293,10 @@ class PeekableAsyncStream(Generic[T]):

         # Otherwise, we need to fetch the next item from the underlying iterator.
         if self._source_iter is None:
-            self._source_iter = self._source
+            self._source_iter = aiter(self._source)

         try:
-            self._buffer = await self._source_iter
+            self._buffer = await anext(self._source_iter)
         except StopAsyncIteration:
             self._exhausted = True
             return UNSET
@@ -318,10 +327,10 @@ class PeekableAsyncStream(Generic[T]):

         # Otherwise, fetch the next item from the source.
         if self._source_iter is None:
-            self._source_iter = self._source
+            self._source_iter = aiter(self._source)

         try:
-            return await self._source_iter
+            return await anext(self._source_iter)
         except StopAsyncIteration:
             self._exhausted = True
             raise
@@ -489,3 +498,12 @@ def get_union_args(tp: Any) -> tuple[Any, ...]:
         return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
     else:
         return ()
+
+
+def get_event_loop():
+    try:
+        event_loop = asyncio.get_event_loop()
+    except RuntimeError:  # pragma: lax no cover
+        event_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(event_loop)
+    return event_loop