jaf-py 2.5.9__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff shows the content of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (92)
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +361 -280
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.9.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/core/streaming.py CHANGED
@@ -14,26 +14,38 @@ from typing import AsyncIterator, Dict, List, Optional, Any, Union, Callable
14
14
  from enum import Enum
15
15
 
16
16
  from .types import (
17
- RunState, RunConfig, Message, TraceEvent, RunId, TraceId,
18
- ContentRole, ToolCall, JAFError, CompletedOutcome, ErrorOutcome, ModelBehaviorError
17
+ RunState,
18
+ RunConfig,
19
+ Message,
20
+ TraceEvent,
21
+ RunId,
22
+ TraceId,
23
+ ContentRole,
24
+ ToolCall,
25
+ JAFError,
26
+ CompletedOutcome,
27
+ ErrorOutcome,
28
+ ModelBehaviorError,
19
29
  )
20
30
 
21
31
 
22
32
  class StreamingEventType(str, Enum):
23
33
  """Types of streaming events."""
24
- START = 'start'
25
- CHUNK = 'chunk'
26
- TOOL_CALL = 'tool_call'
27
- TOOL_RESULT = 'tool_result'
28
- AGENT_SWITCH = 'agent_switch'
29
- ERROR = 'error'
30
- COMPLETE = 'complete'
31
- METADATA = 'metadata'
34
+
35
+ START = "start"
36
+ CHUNK = "chunk"
37
+ TOOL_CALL = "tool_call"
38
+ TOOL_RESULT = "tool_result"
39
+ AGENT_SWITCH = "agent_switch"
40
+ ERROR = "error"
41
+ COMPLETE = "complete"
42
+ METADATA = "metadata"
32
43
 
33
44
 
34
45
  @dataclass(frozen=True)
35
46
  class StreamingChunk:
36
47
  """A chunk of streaming content."""
48
+
37
49
  content: str
38
50
  delta: str # The new content added in this chunk
39
51
  is_complete: bool = False
@@ -43,15 +55,17 @@ class StreamingChunk:
43
55
  @dataclass(frozen=True)
44
56
  class StreamingToolCall:
45
57
  """Streaming tool call information."""
58
+
46
59
  tool_name: str
47
60
  arguments: Dict[str, Any]
48
61
  call_id: str
49
- status: str = 'started' # 'started', 'executing', 'completed', 'failed'
62
+ status: str = "started" # 'started', 'executing', 'completed', 'failed'
50
63
 
51
64
 
52
65
  @dataclass(frozen=True)
53
66
  class StreamingToolResult:
54
67
  """Streaming tool result information."""
68
+
55
69
  tool_name: str
56
70
  call_id: str
57
71
  result: Any
@@ -62,6 +76,7 @@ class StreamingToolResult:
62
76
  @dataclass(frozen=True)
63
77
  class StreamingMetadata:
64
78
  """Metadata about the streaming session."""
79
+
65
80
  agent_name: str
66
81
  model_name: str
67
82
  turn_count: int
@@ -73,25 +88,28 @@ class StreamingMetadata:
73
88
  @dataclass(frozen=True)
74
89
  class StreamingEvent:
75
90
  """A streaming event containing progressive updates."""
91
+
76
92
  type: StreamingEventType
77
93
  data: Union[StreamingChunk, StreamingToolCall, StreamingToolResult, StreamingMetadata, JAFError]
78
94
  timestamp: float = field(default_factory=time.time)
79
95
  run_id: Optional[RunId] = None
80
96
  trace_id: Optional[TraceId] = None
81
-
97
+
82
98
  def to_dict(self) -> Dict[str, Any]:
83
99
  """Convert streaming event to dictionary for serialization."""
84
100
  return {
85
- 'type': self.type.value,
86
- 'data': self._serialize_data(),
87
- 'timestamp': self.timestamp,
88
- 'run_id': str(self.run_id) if self.run_id else None,
89
- 'trace_id': str(self.trace_id) if self.trace_id else None
101
+ "type": self.type.value,
102
+ "data": self._serialize_data(),
103
+ "timestamp": self.timestamp,
104
+ "run_id": str(self.run_id) if self.run_id else None,
105
+ "trace_id": str(self.trace_id) if self.trace_id else None,
90
106
  }
91
-
107
+
92
108
  def _serialize_data(self) -> Dict[str, Any]:
93
109
  """Serialize the data field based on its type."""
94
- if isinstance(self.data, (StreamingChunk, StreamingToolCall, StreamingToolResult, StreamingMetadata)):
110
+ if isinstance(
111
+ self.data, (StreamingChunk, StreamingToolCall, StreamingToolResult, StreamingMetadata)
112
+ ):
95
113
  # Convert dataclass to dict
96
114
  result = {}
97
115
  for field_name, field_value in self.data.__dict__.items():
@@ -100,12 +118,12 @@ class StreamingEvent:
100
118
  return result
101
119
  elif isinstance(self.data, JAFError):
102
120
  return {
103
- 'error_type': self.data._tag,
104
- 'detail': getattr(self.data, 'detail', str(self.data))
121
+ "error_type": self.data._tag,
122
+ "detail": getattr(self.data, "detail", str(self.data)),
105
123
  }
106
124
  else:
107
- return {'value': self.data}
108
-
125
+ return {"value": self.data}
126
+
109
127
  def to_json(self) -> str:
110
128
  """Convert streaming event to JSON string."""
111
129
  return json.dumps(self.to_dict())
@@ -115,7 +133,7 @@ class StreamingBuffer:
115
133
  """
116
134
  Buffer for accumulating streaming content and managing state.
117
135
  """
118
-
136
+
119
137
  def __init__(self):
120
138
  self.content: str = ""
121
139
  self.chunks: List[StreamingChunk] = []
@@ -124,31 +142,31 @@ class StreamingBuffer:
124
142
  self.metadata: Optional[StreamingMetadata] = None
125
143
  self.is_complete: bool = False
126
144
  self.error: Optional[JAFError] = None
127
-
145
+
128
146
  def add_chunk(self, chunk: StreamingChunk) -> None:
129
147
  """Add a content chunk to the buffer."""
130
148
  self.chunks.append(chunk)
131
149
  self.content += chunk.delta
132
150
  if chunk.is_complete:
133
151
  self.is_complete = True
134
-
152
+
135
153
  def add_tool_call(self, tool_call: StreamingToolCall) -> None:
136
154
  """Add a tool call to the buffer."""
137
155
  self.tool_calls.append(tool_call)
138
-
156
+
139
157
  def add_tool_result(self, tool_result: StreamingToolResult) -> None:
140
158
  """Add a tool result to the buffer."""
141
159
  self.tool_results.append(tool_result)
142
-
160
+
143
161
  def set_metadata(self, metadata: StreamingMetadata) -> None:
144
162
  """Set session metadata."""
145
163
  self.metadata = metadata
146
-
164
+
147
165
  def set_error(self, error: JAFError) -> None:
148
166
  """Set error state."""
149
167
  self.error = error
150
168
  self.is_complete = True
151
-
169
+
152
170
  def get_final_message(self) -> Message:
153
171
  """Get the final accumulated message."""
154
172
  tool_calls = None
@@ -156,38 +174,31 @@ class StreamingBuffer:
156
174
  tool_calls = [
157
175
  ToolCall(
158
176
  id=tc.call_id,
159
- type='function',
160
- function={'name': tc.tool_name, 'arguments': json.dumps(tc.arguments)}
177
+ type="function",
178
+ function={"name": tc.tool_name, "arguments": json.dumps(tc.arguments)},
161
179
  )
162
180
  for tc in self.tool_calls
163
181
  ]
164
-
165
- return Message(
166
- role=ContentRole.ASSISTANT,
167
- content=self.content,
168
- tool_calls=tool_calls
169
- )
182
+
183
+ return Message(role=ContentRole.ASSISTANT, content=self.content, tool_calls=tool_calls)
170
184
 
171
185
 
172
186
  async def run_streaming(
173
- initial_state: RunState,
174
- config: RunConfig,
175
- chunk_size: int = 50,
176
- include_metadata: bool = True
187
+ initial_state: RunState, config: RunConfig, chunk_size: int = 50, include_metadata: bool = True
177
188
  ) -> AsyncIterator[StreamingEvent]:
178
189
  """
179
190
  Run an agent with streaming output.
180
-
191
+
181
192
  This function provides real-time streaming of agent responses, tool calls,
182
193
  and execution metadata. It yields StreamingEvent objects that can be
183
194
  consumed by clients for progressive UI updates.
184
-
195
+
185
196
  Args:
186
197
  initial_state: Initial run state
187
198
  config: Run configuration
188
199
  chunk_size: Size of content chunks for streaming (characters)
189
200
  include_metadata: Whether to include performance metadata
190
-
201
+
191
202
  Yields:
192
203
  StreamingEvent: Progressive updates during execution
193
204
  """
@@ -203,10 +214,10 @@ async def run_streaming(
203
214
  model_name="unknown",
204
215
  turn_count=initial_state.turn_count,
205
216
  total_tokens=0,
206
- execution_time_ms=0
217
+ execution_time_ms=0,
207
218
  ),
208
219
  run_id=initial_state.run_id,
209
- trace_id=initial_state.trace_id
220
+ trace_id=initial_state.trace_id,
210
221
  )
211
222
 
212
223
  tool_call_ids: Dict[str, str] = {} # Map call_id -> tool_name for in-flight tool calls
@@ -225,37 +236,36 @@ async def run_streaming(
225
236
  return getattr(payload, key)
226
237
  return None
227
238
 
228
- if event.type == 'tool_call_start':
229
- tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
230
- args = _get_event_value(['args', 'arguments'])
231
- call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
239
+ if event.type == "tool_call_start":
240
+ tool_name = _get_event_value(["tool_name", "toolName"]) or "unknown"
241
+ args = _get_event_value(["args", "arguments"])
242
+ call_id = _get_event_value(["call_id", "tool_call_id", "toolCallId"])
232
243
 
233
244
  if not call_id:
234
245
  call_id = f"call_{uuid.uuid4().hex[:8]}"
235
246
  if isinstance(payload, dict):
236
- payload['call_id'] = call_id
247
+ payload["call_id"] = call_id
237
248
 
238
249
  tool_call_ids[call_id] = tool_name
239
250
 
240
251
  tool_call = StreamingToolCall(
241
- tool_name=tool_name,
242
- arguments=args,
243
- call_id=call_id,
244
- status='started'
252
+ tool_name=tool_name, arguments=args, call_id=call_id, status="started"
245
253
  )
246
254
  streaming_event = StreamingEvent(
247
255
  type=StreamingEventType.TOOL_CALL,
248
256
  data=tool_call,
249
257
  run_id=initial_state.run_id,
250
- trace_id=initial_state.trace_id
258
+ trace_id=initial_state.trace_id,
251
259
  )
252
- elif event.type == 'tool_call_end':
253
- tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
254
- call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
260
+ elif event.type == "tool_call_end":
261
+ tool_name = _get_event_value(["tool_name", "toolName"]) or "unknown"
262
+ call_id = _get_event_value(["call_id", "tool_call_id", "toolCallId"])
255
263
 
256
264
  if not call_id:
257
265
  # Fallback to locate a pending tool call with the same tool name
258
- matching_call_id = next((cid for cid, name in tool_call_ids.items() if name == tool_name), None)
266
+ matching_call_id = next(
267
+ (cid for cid, name in tool_call_ids.items() if name == tool_name), None
268
+ )
259
269
  if matching_call_id:
260
270
  call_id = matching_call_id
261
271
  else:
@@ -268,16 +278,16 @@ async def run_streaming(
268
278
  tool_result = StreamingToolResult(
269
279
  tool_name=tool_name,
270
280
  call_id=call_id,
271
- result=_get_event_value(['result']),
272
- status=_get_event_value(['status']) or 'completed'
281
+ result=_get_event_value(["result"]),
282
+ status=_get_event_value(["status"]) or "completed",
273
283
  )
274
284
  streaming_event = StreamingEvent(
275
285
  type=StreamingEventType.TOOL_RESULT,
276
286
  data=tool_result,
277
287
  run_id=initial_state.run_id,
278
- trace_id=initial_state.trace_id
288
+ trace_id=initial_state.trace_id,
279
289
  )
280
-
290
+
281
291
  if streaming_event:
282
292
  try:
283
293
  event_queue.put_nowait(streaming_event)
@@ -296,7 +306,7 @@ async def run_streaming(
296
306
  on_event=event_handler,
297
307
  memory=config.memory,
298
308
  conversation_id=config.conversation_id,
299
- prefer_streaming=config.prefer_streaming
309
+ prefer_streaming=config.prefer_streaming,
300
310
  )
301
311
 
302
312
  from .engine import run
@@ -330,30 +340,30 @@ async def run_streaming(
330
340
  type=StreamingEventType.ERROR,
331
341
  data=error,
332
342
  run_id=initial_state.run_id,
333
- trace_id=initial_state.trace_id
343
+ trace_id=initial_state.trace_id,
334
344
  )
335
345
  return
336
346
 
337
- if result.outcome.status == 'completed':
347
+ if result.outcome.status == "completed":
338
348
  final_content = str(result.outcome.output) if result.outcome.output else ""
339
-
349
+
340
350
  # Stream content in chunks
341
351
  for i in range(0, len(final_content), chunk_size):
342
- chunk_content = final_content[i:i + chunk_size]
352
+ chunk_content = final_content[i : i + chunk_size]
343
353
  is_final_chunk = i + chunk_size >= len(final_content)
344
-
354
+
345
355
  chunk = StreamingChunk(
346
- content=final_content[:i + len(chunk_content)],
356
+ content=final_content[: i + len(chunk_content)],
347
357
  delta=chunk_content,
348
358
  is_complete=is_final_chunk,
349
- token_count=len(final_content.split()) if is_final_chunk else None
359
+ token_count=len(final_content.split()) if is_final_chunk else None,
350
360
  )
351
-
361
+
352
362
  yield StreamingEvent(
353
363
  type=StreamingEventType.CHUNK,
354
364
  data=chunk,
355
365
  run_id=initial_state.run_id,
356
- trace_id=initial_state.trace_id
366
+ trace_id=initial_state.trace_id,
357
367
  )
358
368
  # Remove artificial delay for better performance
359
369
 
@@ -364,27 +374,27 @@ async def run_streaming(
364
374
  model_name=config.model_override or "default",
365
375
  turn_count=result.final_state.turn_count,
366
376
  total_tokens=len(final_content.split()),
367
- execution_time_ms=execution_time
377
+ execution_time_ms=execution_time,
368
378
  )
369
379
  yield StreamingEvent(
370
380
  type=StreamingEventType.METADATA,
371
381
  data=metadata,
372
382
  run_id=initial_state.run_id,
373
- trace_id=initial_state.trace_id
383
+ trace_id=initial_state.trace_id,
374
384
  )
375
-
385
+
376
386
  yield StreamingEvent(
377
387
  type=StreamingEventType.COMPLETE,
378
388
  data=StreamingChunk(content=final_content, delta="", is_complete=True),
379
389
  run_id=initial_state.run_id,
380
- trace_id=initial_state.trace_id
390
+ trace_id=initial_state.trace_id,
381
391
  )
382
392
  else:
383
393
  yield StreamingEvent(
384
394
  type=StreamingEventType.ERROR,
385
395
  data=result.outcome.error,
386
396
  run_id=initial_state.run_id,
387
- trace_id=initial_state.trace_id
397
+ trace_id=initial_state.trace_id,
388
398
  )
389
399
 
390
400
 
@@ -392,33 +402,31 @@ class StreamingCollector:
392
402
  """
393
403
  Collects streaming events for analysis and replay.
394
404
  """
395
-
405
+
396
406
  def __init__(self):
397
407
  self.events: List[StreamingEvent] = []
398
408
  self.buffers: Dict[str, StreamingBuffer] = {}
399
-
409
+
400
410
  async def collect_stream(
401
- self,
402
- stream: AsyncIterator[StreamingEvent],
403
- run_id: Optional[str] = None
411
+ self, stream: AsyncIterator[StreamingEvent], run_id: Optional[str] = None
404
412
  ) -> StreamingBuffer:
405
413
  """
406
414
  Collect all events from a stream and return the final buffer.
407
-
415
+
408
416
  Args:
409
417
  stream: Async iterator of streaming events
410
418
  run_id: Optional run ID for tracking
411
-
419
+
412
420
  Returns:
413
421
  StreamingBuffer: Final accumulated buffer
414
422
  """
415
423
  buffer_key = run_id or "default"
416
424
  buffer = StreamingBuffer()
417
425
  self.buffers[buffer_key] = buffer
418
-
426
+
419
427
  async for event in stream:
420
428
  self.events.append(event)
421
-
429
+
422
430
  if event.type == StreamingEventType.CHUNK:
423
431
  buffer.add_chunk(event.data)
424
432
  elif event.type == StreamingEventType.TOOL_CALL:
@@ -431,46 +439,45 @@ class StreamingCollector:
431
439
  buffer.set_error(event.data)
432
440
  elif event.type == StreamingEventType.COMPLETE:
433
441
  buffer.is_complete = True
434
-
442
+
435
443
  return buffer
436
-
444
+
437
445
  def get_events_for_run(self, run_id: str) -> List[StreamingEvent]:
438
446
  """Get all events for a specific run."""
439
- return [
440
- event for event in self.events
441
- if event.run_id and str(event.run_id) == run_id
442
- ]
443
-
447
+ return [event for event in self.events if event.run_id and str(event.run_id) == run_id]
448
+
444
449
  def replay_stream(self, run_id: str, delay_ms: int = 50) -> AsyncIterator[StreamingEvent]:
445
450
  """
446
451
  Replay a collected stream with optional delay.
447
-
452
+
448
453
  Args:
449
454
  run_id: Run ID to replay
450
455
  delay_ms: Delay between events in milliseconds
451
-
456
+
452
457
  Yields:
453
458
  StreamingEvent: Replayed events
454
459
  """
460
+
455
461
  async def _replay():
456
462
  events = self.get_events_for_run(run_id)
457
463
  for event in events:
458
464
  yield event
459
465
  if delay_ms > 0:
460
466
  await asyncio.sleep(delay_ms / 1000)
461
-
467
+
462
468
  return _replay()
463
469
 
464
470
 
465
471
  # Utility functions for streaming integration
466
472
 
473
+
467
474
  def create_sse_response(event: StreamingEvent) -> str:
468
475
  """
469
476
  Create a Server-Sent Events (SSE) formatted response.
470
-
477
+
471
478
  Args:
472
479
  event: Streaming event to format
473
-
480
+
474
481
  Returns:
475
482
  str: SSE-formatted string
476
483
  """
@@ -478,12 +485,11 @@ def create_sse_response(event: StreamingEvent) -> str:
478
485
 
479
486
 
480
487
  async def stream_to_websocket(
481
- stream: AsyncIterator[StreamingEvent],
482
- websocket_send: Callable[[str], None]
488
+ stream: AsyncIterator[StreamingEvent], websocket_send: Callable[[str], None]
483
489
  ) -> None:
484
490
  """
485
491
  Stream events to a WebSocket connection.
486
-
492
+
487
493
  Args:
488
494
  stream: Stream of events
489
495
  websocket_send: WebSocket send function
@@ -493,21 +499,22 @@ async def stream_to_websocket(
493
499
 
494
500
 
495
501
  def create_streaming_middleware(
496
- on_event: Optional[Callable[[StreamingEvent], None]] = None
502
+ on_event: Optional[Callable[[StreamingEvent], None]] = None,
497
503
  ) -> Callable[[AsyncIterator[StreamingEvent]], AsyncIterator[StreamingEvent]]:
498
504
  """
499
505
  Create middleware for processing streaming events.
500
-
506
+
501
507
  Args:
502
508
  on_event: Optional callback for each event
503
-
509
+
504
510
  Returns:
505
511
  Middleware function
506
512
  """
513
+
507
514
  async def middleware(stream: AsyncIterator[StreamingEvent]) -> AsyncIterator[StreamingEvent]:
508
515
  async for event in stream:
509
516
  if on_event:
510
517
  on_event(event)
511
518
  yield event
512
-
519
+
513
520
  return middleware