jaf-py 2.5.10-py3-none-any.whl → 2.5.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +269 -210
- jaf/core/types.py +371 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +360 -279
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
- jaf_py-2.5.11.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/core/streaming.py
CHANGED
@@ -14,26 +14,38 @@ from typing import AsyncIterator, Dict, List, Optional, Any, Union, Callable
 from enum import Enum
 
 from .types import (
-    RunState,
-
+    RunState,
+    RunConfig,
+    Message,
+    TraceEvent,
+    RunId,
+    TraceId,
+    ContentRole,
+    ToolCall,
+    JAFError,
+    CompletedOutcome,
+    ErrorOutcome,
+    ModelBehaviorError,
 )
 
 
 class StreamingEventType(str, Enum):
     """Types of streaming events."""
-
-
-
-
-
-
-
-
+
+    START = "start"
+    CHUNK = "chunk"
+    TOOL_CALL = "tool_call"
+    TOOL_RESULT = "tool_result"
+    AGENT_SWITCH = "agent_switch"
+    ERROR = "error"
+    COMPLETE = "complete"
+    METADATA = "metadata"
 
 
 @dataclass(frozen=True)
 class StreamingChunk:
     """A chunk of streaming content."""
+
     content: str
     delta: str  # The new content added in this chunk
     is_complete: bool = False
@@ -43,15 +55,17 @@ class StreamingChunk:
 @dataclass(frozen=True)
 class StreamingToolCall:
     """Streaming tool call information."""
+
     tool_name: str
     arguments: Dict[str, Any]
     call_id: str
-    status: str =
+    status: str = "started"  # 'started', 'executing', 'completed', 'failed'
 
 
 @dataclass(frozen=True)
 class StreamingToolResult:
     """Streaming tool result information."""
+
     tool_name: str
     call_id: str
     result: Any
@@ -62,6 +76,7 @@ class StreamingToolResult:
 @dataclass(frozen=True)
 class StreamingMetadata:
     """Metadata about the streaming session."""
+
     agent_name: str
     model_name: str
     turn_count: int
@@ -73,25 +88,28 @@ class StreamingMetadata:
 @dataclass(frozen=True)
 class StreamingEvent:
     """A streaming event containing progressive updates."""
+
     type: StreamingEventType
     data: Union[StreamingChunk, StreamingToolCall, StreamingToolResult, StreamingMetadata, JAFError]
     timestamp: float = field(default_factory=time.time)
     run_id: Optional[RunId] = None
     trace_id: Optional[TraceId] = None
-
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert streaming event to dictionary for serialization."""
         return {
-
-
-
-
-
+            "type": self.type.value,
+            "data": self._serialize_data(),
+            "timestamp": self.timestamp,
+            "run_id": str(self.run_id) if self.run_id else None,
+            "trace_id": str(self.trace_id) if self.trace_id else None,
         }
-
+
     def _serialize_data(self) -> Dict[str, Any]:
         """Serialize the data field based on its type."""
-        if isinstance(
+        if isinstance(
+            self.data, (StreamingChunk, StreamingToolCall, StreamingToolResult, StreamingMetadata)
+        ):
             # Convert dataclass to dict
             result = {}
             for field_name, field_value in self.data.__dict__.items():
@@ -100,12 +118,12 @@ class StreamingEvent:
             return result
         elif isinstance(self.data, JAFError):
             return {
-
-
+                "error_type": self.data._tag,
+                "detail": getattr(self.data, "detail", str(self.data)),
             }
         else:
-            return {
-
+            return {"value": self.data}
+
     def to_json(self) -> str:
         """Convert streaming event to JSON string."""
         return json.dumps(self.to_dict())
@@ -115,7 +133,7 @@ class StreamingBuffer:
     """
     Buffer for accumulating streaming content and managing state.
     """
-
+
     def __init__(self):
         self.content: str = ""
         self.chunks: List[StreamingChunk] = []
@@ -124,31 +142,31 @@
         self.metadata: Optional[StreamingMetadata] = None
         self.is_complete: bool = False
         self.error: Optional[JAFError] = None
-
+
     def add_chunk(self, chunk: StreamingChunk) -> None:
         """Add a content chunk to the buffer."""
        self.chunks.append(chunk)
         self.content += chunk.delta
         if chunk.is_complete:
             self.is_complete = True
-
+
     def add_tool_call(self, tool_call: StreamingToolCall) -> None:
         """Add a tool call to the buffer."""
         self.tool_calls.append(tool_call)
-
+
     def add_tool_result(self, tool_result: StreamingToolResult) -> None:
         """Add a tool result to the buffer."""
         self.tool_results.append(tool_result)
-
+
     def set_metadata(self, metadata: StreamingMetadata) -> None:
         """Set session metadata."""
         self.metadata = metadata
-
+
     def set_error(self, error: JAFError) -> None:
         """Set error state."""
         self.error = error
         self.is_complete = True
-
+
     def get_final_message(self) -> Message:
         """Get the final accumulated message."""
         tool_calls = None
@@ -156,38 +174,31 @@
             tool_calls = [
                 ToolCall(
                     id=tc.call_id,
-                    type=
-                    function={
+                    type="function",
+                    function={"name": tc.tool_name, "arguments": json.dumps(tc.arguments)},
                 )
                 for tc in self.tool_calls
             ]
-
-        return Message(
-            role=ContentRole.ASSISTANT,
-            content=self.content,
-            tool_calls=tool_calls
-        )
+
+        return Message(role=ContentRole.ASSISTANT, content=self.content, tool_calls=tool_calls)
 
 
 async def run_streaming(
-    initial_state: RunState,
-    config: RunConfig,
-    chunk_size: int = 50,
-    include_metadata: bool = True
+    initial_state: RunState, config: RunConfig, chunk_size: int = 50, include_metadata: bool = True
 ) -> AsyncIterator[StreamingEvent]:
     """
     Run an agent with streaming output.
-
+
     This function provides real-time streaming of agent responses, tool calls,
     and execution metadata. It yields StreamingEvent objects that can be
     consumed by clients for progressive UI updates.
-
+
     Args:
         initial_state: Initial run state
         config: Run configuration
         chunk_size: Size of content chunks for streaming (characters)
         include_metadata: Whether to include performance metadata
-
+
     Yields:
         StreamingEvent: Progressive updates during execution
     """
@@ -203,10 +214,10 @@ async def run_streaming(
             model_name="unknown",
             turn_count=initial_state.turn_count,
             total_tokens=0,
-            execution_time_ms=0
+            execution_time_ms=0,
         ),
         run_id=initial_state.run_id,
-        trace_id=initial_state.trace_id
+        trace_id=initial_state.trace_id,
     )
 
     tool_call_ids: Dict[str, str] = {}  # Map call_id -> tool_name for in-flight tool calls
@@ -225,37 +236,36 @@
                 return getattr(payload, key)
             return None
 
-        if event.type ==
-            tool_name = _get_event_value([
-            args = _get_event_value([
-            call_id = _get_event_value([
+        if event.type == "tool_call_start":
+            tool_name = _get_event_value(["tool_name", "toolName"]) or "unknown"
+            args = _get_event_value(["args", "arguments"])
+            call_id = _get_event_value(["call_id", "tool_call_id", "toolCallId"])
 
             if not call_id:
                 call_id = f"call_{uuid.uuid4().hex[:8]}"
                 if isinstance(payload, dict):
-                    payload[
+                    payload["call_id"] = call_id
 
             tool_call_ids[call_id] = tool_name
 
             tool_call = StreamingToolCall(
-                tool_name=tool_name,
-                arguments=args,
-                call_id=call_id,
-                status='started'
+                tool_name=tool_name, arguments=args, call_id=call_id, status="started"
            )
             streaming_event = StreamingEvent(
                 type=StreamingEventType.TOOL_CALL,
                 data=tool_call,
                 run_id=initial_state.run_id,
-                trace_id=initial_state.trace_id
+                trace_id=initial_state.trace_id,
             )
-        elif event.type ==
-            tool_name = _get_event_value([
-            call_id = _get_event_value([
+        elif event.type == "tool_call_end":
+            tool_name = _get_event_value(["tool_name", "toolName"]) or "unknown"
+            call_id = _get_event_value(["call_id", "tool_call_id", "toolCallId"])
 
             if not call_id:
                 # Fallback to locate a pending tool call with the same tool name
-                matching_call_id = next(
+                matching_call_id = next(
+                    (cid for cid, name in tool_call_ids.items() if name == tool_name), None
+                )
                 if matching_call_id:
                     call_id = matching_call_id
                 else:
@@ -268,16 +278,16 @@ async def run_streaming(
            tool_result = StreamingToolResult(
                 tool_name=tool_name,
                 call_id=call_id,
-                result=_get_event_value([
-                status=_get_event_value([
+                result=_get_event_value(["result"]),
+                status=_get_event_value(["status"]) or "completed",
             )
             streaming_event = StreamingEvent(
                 type=StreamingEventType.TOOL_RESULT,
                 data=tool_result,
                 run_id=initial_state.run_id,
-                trace_id=initial_state.trace_id
+                trace_id=initial_state.trace_id,
             )
-
+
         if streaming_event:
             try:
                 event_queue.put_nowait(streaming_event)
@@ -296,7 +306,7 @@
         on_event=event_handler,
         memory=config.memory,
         conversation_id=config.conversation_id,
-        prefer_streaming=config.prefer_streaming
+        prefer_streaming=config.prefer_streaming,
     )
 
     from .engine import run
@@ -330,30 +340,30 @@
             type=StreamingEventType.ERROR,
             data=error,
             run_id=initial_state.run_id,
-            trace_id=initial_state.trace_id
+            trace_id=initial_state.trace_id,
         )
         return
 
-    if result.outcome.status ==
+    if result.outcome.status == "completed":
         final_content = str(result.outcome.output) if result.outcome.output else ""
-
+
         # Stream content in chunks
         for i in range(0, len(final_content), chunk_size):
-            chunk_content = final_content[i:i + chunk_size]
+            chunk_content = final_content[i : i + chunk_size]
             is_final_chunk = i + chunk_size >= len(final_content)
-
+
             chunk = StreamingChunk(
-                content=final_content[:i + len(chunk_content)],
+                content=final_content[: i + len(chunk_content)],
                 delta=chunk_content,
                 is_complete=is_final_chunk,
-                token_count=len(final_content.split()) if is_final_chunk else None
+                token_count=len(final_content.split()) if is_final_chunk else None,
             )
-
+
             yield StreamingEvent(
                 type=StreamingEventType.CHUNK,
                 data=chunk,
                 run_id=initial_state.run_id,
-                trace_id=initial_state.trace_id
+                trace_id=initial_state.trace_id,
             )
             # Remove artificial delay for better performance
 
@@ -364,27 +374,27 @@
                 model_name=config.model_override or "default",
                 turn_count=result.final_state.turn_count,
                 total_tokens=len(final_content.split()),
-                execution_time_ms=execution_time
+                execution_time_ms=execution_time,
            )
             yield StreamingEvent(
                 type=StreamingEventType.METADATA,
                 data=metadata,
                 run_id=initial_state.run_id,
-                trace_id=initial_state.trace_id
+                trace_id=initial_state.trace_id,
             )
-
+
         yield StreamingEvent(
             type=StreamingEventType.COMPLETE,
             data=StreamingChunk(content=final_content, delta="", is_complete=True),
             run_id=initial_state.run_id,
-            trace_id=initial_state.trace_id
+            trace_id=initial_state.trace_id,
         )
     else:
         yield StreamingEvent(
             type=StreamingEventType.ERROR,
            data=result.outcome.error,
             run_id=initial_state.run_id,
-            trace_id=initial_state.trace_id
+            trace_id=initial_state.trace_id,
         )
 
 
@@ -392,33 +402,31 @@ class StreamingCollector:
     """
     Collects streaming events for analysis and replay.
     """
-
+
     def __init__(self):
         self.events: List[StreamingEvent] = []
         self.buffers: Dict[str, StreamingBuffer] = {}
-
+
     async def collect_stream(
-        self,
-        stream: AsyncIterator[StreamingEvent],
-        run_id: Optional[str] = None
+        self, stream: AsyncIterator[StreamingEvent], run_id: Optional[str] = None
     ) -> StreamingBuffer:
         """
         Collect all events from a stream and return the final buffer.
-
+
         Args:
             stream: Async iterator of streaming events
             run_id: Optional run ID for tracking
-
+
         Returns:
             StreamingBuffer: Final accumulated buffer
         """
         buffer_key = run_id or "default"
         buffer = StreamingBuffer()
         self.buffers[buffer_key] = buffer
-
+
         async for event in stream:
             self.events.append(event)
-
+
             if event.type == StreamingEventType.CHUNK:
                 buffer.add_chunk(event.data)
             elif event.type == StreamingEventType.TOOL_CALL:
@@ -431,46 +439,45 @@ class StreamingCollector:
                 buffer.set_error(event.data)
             elif event.type == StreamingEventType.COMPLETE:
                 buffer.is_complete = True
-
+
         return buffer
-
+
     def get_events_for_run(self, run_id: str) -> List[StreamingEvent]:
         """Get all events for a specific run."""
-        return [
-
-            if event.run_id and str(event.run_id) == run_id
-        ]
-
+        return [event for event in self.events if event.run_id and str(event.run_id) == run_id]
+
     def replay_stream(self, run_id: str, delay_ms: int = 50) -> AsyncIterator[StreamingEvent]:
         """
         Replay a collected stream with optional delay.
-
+
         Args:
             run_id: Run ID to replay
             delay_ms: Delay between events in milliseconds
-
+
         Yields:
             StreamingEvent: Replayed events
         """
+
         async def _replay():
             events = self.get_events_for_run(run_id)
             for event in events:
                 yield event
                 if delay_ms > 0:
                     await asyncio.sleep(delay_ms / 1000)
-
+
         return _replay()
 
 
 # Utility functions for streaming integration
 
+
 def create_sse_response(event: StreamingEvent) -> str:
     """
     Create a Server-Sent Events (SSE) formatted response.
-
+
     Args:
         event: Streaming event to format
-
+
     Returns:
         str: SSE-formatted string
     """
@@ -478,12 +485,11 @@ def create_sse_response(event: StreamingEvent) -> str:
 
 
 async def stream_to_websocket(
-    stream: AsyncIterator[StreamingEvent],
-    websocket_send: Callable[[str], None]
+    stream: AsyncIterator[StreamingEvent], websocket_send: Callable[[str], None]
 ) -> None:
     """
     Stream events to a WebSocket connection.
-
+
     Args:
         stream: Stream of events
         websocket_send: WebSocket send function
@@ -493,21 +499,22 @@ async def stream_to_websocket(
 
 
 def create_streaming_middleware(
-    on_event: Optional[Callable[[StreamingEvent], None]] = None
+    on_event: Optional[Callable[[StreamingEvent], None]] = None,
 ) -> Callable[[AsyncIterator[StreamingEvent]], AsyncIterator[StreamingEvent]]:
     """
     Create middleware for processing streaming events.
-
+
     Args:
         on_event: Optional callback for each event
-
+
     Returns:
         Middleware function
     """
+
     async def middleware(stream: AsyncIterator[StreamingEvent]) -> AsyncIterator[StreamingEvent]:
         async for event in stream:
             if on_event:
                 on_event(event)
             yield event
-
+
     return middleware