flock-core 0.5.20__py3-none-any.whl → 0.5.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flock-core might be problematic. See the package registry's advisory page for more details.

@@ -13,12 +13,13 @@ from __future__ import annotations
13
13
  import asyncio
14
14
  from collections import OrderedDict, defaultdict
15
15
  from contextlib import nullcontext
16
- from datetime import UTC
17
- from typing import Any
16
+ from datetime import UTC, datetime
17
+ from typing import Any, Awaitable, Callable, Sequence
18
18
 
19
19
  from pydantic import BaseModel
20
20
 
21
21
  from flock.dashboard.events import StreamingOutputEvent
22
+ from flock.engines.streaming.sinks import RichSink, StreamSink, WebSocketSink
22
23
  from flock.logging.logging import get_logger
23
24
 
24
25
 
@@ -56,6 +57,379 @@ class DSPyStreamingExecutor:
56
57
  self.stream_vertical_overflow = stream_vertical_overflow
57
58
  self.theme = theme
58
59
  self.no_output = no_output
60
+ self._model_stream_cls: Any | None = None
61
+
62
+ def _make_listeners(self, dspy_mod, signature) -> list[Any]:
63
+ """Create DSPy stream listeners for string output fields."""
64
+ streaming_mod = getattr(dspy_mod, "streaming", None)
65
+ if not streaming_mod or not hasattr(streaming_mod, "StreamListener"):
66
+ return []
67
+
68
+ listeners: list[Any] = []
69
+ try:
70
+ for name, field in getattr(signature, "output_fields", {}).items():
71
+ if getattr(field, "annotation", None) is str:
72
+ listeners.append(
73
+ streaming_mod.StreamListener(signature_field_name=name)
74
+ )
75
+ except Exception:
76
+ return []
77
+ return listeners
78
+
79
+ def _payload_kwargs(self, *, payload: Any, description: str) -> dict[str, Any]:
80
+ """Normalize payload variations into kwargs for streamify."""
81
+ if isinstance(payload, dict) and "description" in payload:
82
+ return payload
83
+
84
+ if isinstance(payload, dict) and "input" in payload:
85
+ return {
86
+ "description": description,
87
+ "input": payload["input"],
88
+ "context": payload.get("context", []),
89
+ }
90
+
91
+ # Legacy fallback: treat payload as the primary input.
92
+ return {"description": description, "input": payload, "context": []}
93
+
94
+ def _artifact_type_label(self, agent: Any, output_group: Any) -> str:
95
+ """Derive user-facing artifact label for streaming events."""
96
+ outputs_to_display = (
97
+ getattr(output_group, "outputs", None)
98
+ if output_group and hasattr(output_group, "outputs")
99
+ else getattr(agent, "outputs", [])
100
+ if hasattr(agent, "outputs")
101
+ else []
102
+ )
103
+
104
+ if not outputs_to_display:
105
+ return "output"
106
+
107
+ # Preserve ordering while avoiding duplicates.
108
+ seen: set[str] = set()
109
+ segments: list[str] = []
110
+ for output in outputs_to_display:
111
+ type_name = getattr(getattr(output, "spec", None), "type_name", None)
112
+ if type_name and type_name not in seen:
113
+ seen.add(type_name)
114
+ segments.append(type_name)
115
+
116
+ return ", ".join(segments) if segments else "output"
117
+
118
+ def _streaming_classes_for(self, dspy_mod: Any) -> tuple[type | None, type | None]:
119
+ streaming_mod = getattr(dspy_mod, "streaming", None)
120
+ if not streaming_mod:
121
+ return None, None
122
+ status_cls = getattr(streaming_mod, "StatusMessage", None)
123
+ stream_cls = getattr(streaming_mod, "StreamResponse", None)
124
+ return status_cls, stream_cls
125
+
126
+ def _resolve_model_stream_cls(self) -> Any | None:
127
+ if self._model_stream_cls is None:
128
+ try:
129
+ from litellm import ModelResponseStream # type: ignore
130
+ except Exception: # pragma: no cover - litellm optional at runtime
131
+ self._model_stream_cls = False
132
+ else:
133
+ self._model_stream_cls = ModelResponseStream
134
+ return self._model_stream_cls or None
135
+
136
+ @staticmethod
137
+ def _normalize_status_message(
138
+ value: Any,
139
+ ) -> tuple[str, str | None, str | None, Any | None]:
140
+ message = getattr(value, "message", "")
141
+ return "status", str(message), None, None
142
+
143
+ @staticmethod
144
+ def _normalize_stream_response(
145
+ value: Any,
146
+ ) -> tuple[str, str | None, str | None, Any | None]:
147
+ chunk = getattr(value, "chunk", None)
148
+ signature_field = getattr(value, "signature_field_name", None)
149
+ return "token", ("" if chunk is None else str(chunk)), signature_field, None
150
+
151
+ @staticmethod
152
+ def _normalize_model_stream(
153
+ value: Any,
154
+ ) -> tuple[str, str | None, str | None, Any | None]:
155
+ token_text = ""
156
+ try:
157
+ token_text = value.choices[0].delta.content or ""
158
+ except Exception: # pragma: no cover - defensive parity with legacy path
159
+ token_text = ""
160
+ signature_field = getattr(value, "signature_field_name", None)
161
+ return "token", str(token_text), signature_field, None
162
+
163
+ def _initialize_display_data(
164
+ self,
165
+ *,
166
+ signature_order: Sequence[str],
167
+ agent: Any,
168
+ ctx: Any,
169
+ pre_generated_artifact_id: Any,
170
+ output_group: Any,
171
+ status_field: str,
172
+ ) -> tuple[OrderedDict[str, Any], str]:
173
+ """Build the initial Rich display structure for CLI streaming."""
174
+ display_data: OrderedDict[str, Any] = OrderedDict()
175
+ display_data["id"] = str(pre_generated_artifact_id)
176
+
177
+ artifact_type_name = self._artifact_type_label(agent, output_group)
178
+ display_data["type"] = artifact_type_name
179
+
180
+ payload_section: OrderedDict[str, Any] = OrderedDict()
181
+ for field_name in signature_order:
182
+ if field_name != "description":
183
+ payload_section[field_name] = ""
184
+ display_data["payload"] = payload_section
185
+
186
+ display_data["produced_by"] = getattr(agent, "name", "")
187
+ correlation_id = None
188
+ if ctx and getattr(ctx, "correlation_id", None):
189
+ correlation_id = str(ctx.correlation_id)
190
+ display_data["correlation_id"] = correlation_id
191
+ display_data["partition_key"] = None
192
+ display_data["tags"] = "set()"
193
+ display_data["visibility"] = OrderedDict([("kind", "Public")])
194
+ display_data["created_at"] = "streaming..."
195
+ display_data["version"] = 1
196
+ display_data["status"] = status_field
197
+
198
+ return display_data, artifact_type_name
199
+
200
+ def _prepare_rich_env(
201
+ self,
202
+ *,
203
+ console,
204
+ display_data: OrderedDict[str, Any],
205
+ agent: Any,
206
+ overflow_mode: str,
207
+ ) -> tuple[Any, dict[str, Any], dict[str, Any], str, Any]:
208
+ """Create formatter metadata and Live context for Rich output."""
209
+ from rich.live import Live
210
+
211
+ from flock.engines.dspy_engine import _ensure_live_crop_above
212
+
213
+ _ensure_live_crop_above()
214
+ formatter, theme_dict, styles, agent_label = self.prepare_stream_formatter(
215
+ agent
216
+ )
217
+ initial_panel = formatter.format_result(
218
+ display_data, agent_label, theme_dict, styles
219
+ )
220
+ live_cm = Live(
221
+ initial_panel,
222
+ console=console,
223
+ refresh_per_second=4,
224
+ transient=False,
225
+ vertical_overflow=overflow_mode,
226
+ )
227
+ return formatter, theme_dict, styles, agent_label, live_cm
228
+
229
+ def _build_rich_sink(
230
+ self,
231
+ *,
232
+ live: Any,
233
+ formatter: Any | None,
234
+ display_data: OrderedDict[str, Any],
235
+ agent_label: str | None,
236
+ theme_dict: dict[str, Any] | None,
237
+ styles: dict[str, Any] | None,
238
+ status_field: str,
239
+ signature_order: Sequence[str],
240
+ stream_buffers: defaultdict[str, list[str]],
241
+ timestamp_factory: Callable[[], str],
242
+ ) -> RichSink | None:
243
+ if formatter is None or live is None:
244
+ return None
245
+
246
+ def refresh_panel() -> None:
247
+ live.update(
248
+ formatter.format_result(display_data, agent_label, theme_dict, styles)
249
+ )
250
+
251
+ return RichSink(
252
+ display_data=display_data,
253
+ stream_buffers=stream_buffers,
254
+ status_field=status_field,
255
+ signature_order=signature_order,
256
+ formatter=formatter,
257
+ theme_dict=theme_dict,
258
+ styles=styles,
259
+ agent_label=agent_label,
260
+ refresh_panel=refresh_panel,
261
+ timestamp_factory=timestamp_factory,
262
+ )
263
+
264
+ def _build_websocket_sink(
265
+ self,
266
+ *,
267
+ ws_broadcast: Callable[[StreamingOutputEvent], Awaitable[None]] | None,
268
+ ctx: Any,
269
+ agent: Any,
270
+ pre_generated_artifact_id: Any,
271
+ artifact_type_name: str,
272
+ ) -> WebSocketSink | None:
273
+ if not ws_broadcast:
274
+ return None
275
+
276
+ def event_factory(
277
+ output_type: str, content: str, sequence: int, is_final: bool
278
+ ) -> StreamingOutputEvent:
279
+ return self._build_event(
280
+ ctx=ctx,
281
+ agent=agent,
282
+ artifact_id=pre_generated_artifact_id,
283
+ artifact_type=artifact_type_name,
284
+ output_type=output_type,
285
+ content=content,
286
+ sequence=sequence,
287
+ is_final=is_final,
288
+ )
289
+
290
+ return WebSocketSink(ws_broadcast=ws_broadcast, event_factory=event_factory)
291
+
292
+ def _collect_sinks(
293
+ self,
294
+ *,
295
+ rich_sink: RichSink | None,
296
+ ws_sink: WebSocketSink | None,
297
+ ) -> list[StreamSink]:
298
+ sinks: list[StreamSink] = []
299
+ if rich_sink:
300
+ sinks.append(rich_sink)
301
+ if ws_sink:
302
+ sinks.append(ws_sink)
303
+ return sinks
304
+
305
+ async def _dispatch_to_sinks(
306
+ self, sinks: Sequence[StreamSink], method: str, *args: Any
307
+ ) -> None:
308
+ for sink in sinks:
309
+ await getattr(sink, method)(*args)
310
+
311
+ async def _consume_stream(
312
+ self,
313
+ stream_generator: Any,
314
+ sinks: Sequence[StreamSink],
315
+ dspy_mod: Any,
316
+ ) -> tuple[Any | None, int]:
317
+ tokens_emitted = 0
318
+ final_result: Any | None = None
319
+
320
+ async for value in stream_generator:
321
+ kind, text, signature_field, prediction = self._normalize_value(
322
+ value, dspy_mod
323
+ )
324
+
325
+ if kind == "status" and text:
326
+ await self._dispatch_to_sinks(sinks, "on_status", text)
327
+ continue
328
+
329
+ if kind == "token" and text:
330
+ tokens_emitted += 1
331
+ await self._dispatch_to_sinks(sinks, "on_token", text, signature_field)
332
+ continue
333
+
334
+ if kind == "prediction":
335
+ final_result = prediction
336
+ await self._dispatch_to_sinks(
337
+ sinks, "on_final", prediction, tokens_emitted
338
+ )
339
+ await self._close_stream_generator(stream_generator)
340
+ return final_result, tokens_emitted
341
+
342
+ return final_result, tokens_emitted
343
+
344
+ async def _flush_sinks(self, sinks: Sequence[StreamSink]) -> None:
345
+ for sink in sinks:
346
+ await sink.flush()
347
+
348
+ def _finalize_stream_display(
349
+ self,
350
+ *,
351
+ rich_sink: RichSink | None,
352
+ formatter: Any | None,
353
+ display_data: OrderedDict[str, Any],
354
+ theme_dict: dict[str, Any] | None,
355
+ styles: dict[str, Any] | None,
356
+ agent_label: str | None,
357
+ ) -> tuple[Any, OrderedDict, dict | None, dict | None, str | None]:
358
+ if rich_sink:
359
+ return rich_sink.final_display_data
360
+ return formatter, display_data, theme_dict, styles, agent_label
361
+
362
+ @staticmethod
363
+ async def _close_stream_generator(stream_generator: Any) -> None:
364
+ aclose = getattr(stream_generator, "aclose", None)
365
+ if callable(aclose):
366
+ try:
367
+ await aclose()
368
+ except GeneratorExit:
369
+ pass
370
+ except BaseExceptionGroup as exc: # pragma: no cover - defensive logging
371
+ remaining = [
372
+ err
373
+ for err in getattr(exc, "exceptions", [])
374
+ if not isinstance(err, GeneratorExit)
375
+ ]
376
+ if remaining:
377
+ logger.debug("Error closing stream generator", exc_info=True)
378
+ except Exception:
379
+ logger.debug("Error closing stream generator", exc_info=True)
380
+
381
+ def _build_event(
382
+ self,
383
+ *,
384
+ ctx: Any,
385
+ agent: Any,
386
+ artifact_id: Any,
387
+ artifact_type: str,
388
+ output_type: str,
389
+ content: str,
390
+ sequence: int,
391
+ is_final: bool,
392
+ ) -> StreamingOutputEvent:
393
+ """Construct a StreamingOutputEvent with consistent metadata."""
394
+ correlation_id = ""
395
+ run_id = ""
396
+ if ctx:
397
+ correlation_id = str(getattr(ctx, "correlation_id", "") or "")
398
+ run_id = str(getattr(ctx, "task_id", "") or "")
399
+
400
+ return StreamingOutputEvent(
401
+ correlation_id=correlation_id,
402
+ agent_name=getattr(agent, "name", ""),
403
+ run_id=run_id,
404
+ output_type=output_type,
405
+ content=content,
406
+ sequence=sequence,
407
+ is_final=is_final,
408
+ artifact_id=str(artifact_id) if artifact_id is not None else "",
409
+ artifact_type=artifact_type,
410
+ )
411
+
412
+ def _normalize_value(
413
+ self, value: Any, dspy_mod: Any
414
+ ) -> tuple[str, str | None, str | None, Any | None]:
415
+ """Normalize raw DSPy streaming values into (kind, text, field, final)."""
416
+ status_cls, stream_cls = self._streaming_classes_for(dspy_mod)
417
+ model_stream_cls = self._resolve_model_stream_cls()
418
+ prediction_cls = getattr(dspy_mod, "Prediction", None)
419
+
420
+ if status_cls and isinstance(value, status_cls):
421
+ return self._normalize_status_message(value)
422
+
423
+ if stream_cls and isinstance(value, stream_cls):
424
+ return self._normalize_stream_response(value)
425
+
426
+ if model_stream_cls and isinstance(value, model_stream_cls):
427
+ return self._normalize_model_stream(value)
428
+
429
+ if prediction_cls and isinstance(value, prediction_cls):
430
+ return "prediction", None, None, value
431
+
432
+ return "unknown", None, None, None
59
433
 
60
434
  async def execute_standard(
61
435
  self, dspy_mod, program, *, description: str, payload: dict[str, Any]
@@ -136,32 +510,8 @@ class DSPyStreamingExecutor:
136
510
  )
137
511
  return result, None
138
512
 
139
- # Get artifact type name for WebSocket events
140
- artifact_type_name = "output"
141
- # Use output_group.outputs (current group) if available, otherwise fallback to agent.outputs (all groups)
142
- outputs_to_display = (
143
- output_group.outputs
144
- if output_group and hasattr(output_group, "outputs")
145
- else agent.outputs
146
- if hasattr(agent, "outputs")
147
- else []
148
- )
149
-
150
- if outputs_to_display:
151
- artifact_type_name = outputs_to_display[0].spec.type_name
152
-
153
- # Prepare stream listeners
154
- listeners = []
155
- try:
156
- streaming_mod = getattr(dspy_mod, "streaming", None)
157
- if streaming_mod and hasattr(streaming_mod, "StreamListener"):
158
- for name, field in signature.output_fields.items():
159
- if field.annotation is str:
160
- listeners.append(
161
- streaming_mod.StreamListener(signature_field_name=name)
162
- )
163
- except Exception:
164
- listeners = []
513
+ artifact_type_name = self._artifact_type_label(agent, output_group)
514
+ listeners = self._make_listeners(dspy_mod, signature)
165
515
 
166
516
  # Create streaming task
167
517
  streaming_task = dspy_mod.streamify(
@@ -170,157 +520,52 @@ class DSPyStreamingExecutor:
170
520
  stream_listeners=listeners if listeners else None,
171
521
  )
172
522
 
173
- # Execute with appropriate payload format
174
- if isinstance(payload, dict) and "description" in payload:
175
- # Semantic fields: pass all fields as kwargs
176
- stream_generator = streaming_task(**payload)
177
- elif isinstance(payload, dict) and "input" in payload:
178
- # Legacy format: {"input": ..., "context": ...}
179
- stream_generator = streaming_task(
180
- description=description,
181
- input=payload["input"],
182
- context=payload.get("context", []),
183
- )
184
- else:
185
- # Old format: direct payload
186
- stream_generator = streaming_task(
187
- description=description, input=payload, context=[]
523
+ stream_kwargs = self._payload_kwargs(payload=payload, description=description)
524
+ stream_generator = streaming_task(**stream_kwargs)
525
+
526
+ def event_factory(
527
+ output_type: str, content: str, sequence: int, is_final: bool
528
+ ) -> StreamingOutputEvent:
529
+ return self._build_event(
530
+ ctx=ctx,
531
+ agent=agent,
532
+ artifact_id=pre_generated_artifact_id,
533
+ artifact_type=artifact_type_name,
534
+ output_type=output_type,
535
+ content=content,
536
+ sequence=sequence,
537
+ is_final=is_final,
188
538
  )
189
539
 
190
- # Process stream (WebSocket only, no Rich display)
191
- final_result = None
192
- stream_sequence = 0
540
+ sink: StreamSink = WebSocketSink(
541
+ ws_broadcast=ws_broadcast,
542
+ event_factory=event_factory,
543
+ )
193
544
 
194
- # Track background WebSocket broadcast tasks to prevent garbage collection
195
- # Using fire-and-forget pattern to avoid blocking DSPy's streaming loop
196
- ws_broadcast_tasks: set[asyncio.Task] = set()
545
+ final_result = None
546
+ tokens_emitted = 0
197
547
 
198
548
  async for value in stream_generator:
199
- try:
200
- from dspy.streaming import StatusMessage, StreamResponse
201
- from litellm import ModelResponseStream
202
- except Exception:
203
- StatusMessage = object # type: ignore
204
- StreamResponse = object # type: ignore
205
- ModelResponseStream = object # type: ignore
206
-
207
- if isinstance(value, StatusMessage):
208
- token = getattr(value, "message", "")
209
- if token:
210
- try:
211
- event = StreamingOutputEvent(
212
- correlation_id=str(ctx.correlation_id)
213
- if ctx and ctx.correlation_id
214
- else "",
215
- agent_name=agent.name,
216
- run_id=ctx.task_id if ctx else "",
217
- output_type="log",
218
- content=str(token + "\n"),
219
- sequence=stream_sequence,
220
- is_final=False,
221
- artifact_id=str(pre_generated_artifact_id),
222
- artifact_type=artifact_type_name,
223
- )
224
- # Fire-and-forget to avoid blocking DSPy's streaming loop
225
- task = asyncio.create_task(ws_broadcast(event))
226
- ws_broadcast_tasks.add(task)
227
- task.add_done_callback(ws_broadcast_tasks.discard)
228
- stream_sequence += 1
229
- except Exception as e:
230
- logger.warning(f"Failed to emit streaming event: {e}")
231
-
232
- elif isinstance(value, StreamResponse):
233
- token = getattr(value, "chunk", None)
234
- if token:
235
- try:
236
- event = StreamingOutputEvent(
237
- correlation_id=str(ctx.correlation_id)
238
- if ctx and ctx.correlation_id
239
- else "",
240
- agent_name=agent.name,
241
- run_id=ctx.task_id if ctx else "",
242
- output_type="llm_token",
243
- content=str(token),
244
- sequence=stream_sequence,
245
- is_final=False,
246
- artifact_id=str(pre_generated_artifact_id),
247
- artifact_type=artifact_type_name,
248
- )
249
- # Fire-and-forget to avoid blocking DSPy's streaming loop
250
- task = asyncio.create_task(ws_broadcast(event))
251
- ws_broadcast_tasks.add(task)
252
- task.add_done_callback(ws_broadcast_tasks.discard)
253
- stream_sequence += 1
254
- except Exception as e:
255
- logger.warning(f"Failed to emit streaming event: {e}")
256
-
257
- elif isinstance(value, ModelResponseStream):
258
- chunk = value
259
- token = chunk.choices[0].delta.content or ""
260
- if token:
261
- try:
262
- event = StreamingOutputEvent(
263
- correlation_id=str(ctx.correlation_id)
264
- if ctx and ctx.correlation_id
265
- else "",
266
- agent_name=agent.name,
267
- run_id=ctx.task_id if ctx else "",
268
- output_type="llm_token",
269
- content=str(token),
270
- sequence=stream_sequence,
271
- is_final=False,
272
- artifact_id=str(pre_generated_artifact_id),
273
- artifact_type=artifact_type_name,
274
- )
275
- # Fire-and-forget to avoid blocking DSPy's streaming loop
276
- task = asyncio.create_task(ws_broadcast(event))
277
- ws_broadcast_tasks.add(task)
278
- task.add_done_callback(ws_broadcast_tasks.discard)
279
- stream_sequence += 1
280
- except Exception as e:
281
- logger.warning(f"Failed to emit streaming event: {e}")
282
-
283
- elif isinstance(value, dspy_mod.Prediction):
284
- final_result = value
285
- # Send final events
286
- try:
287
- event = StreamingOutputEvent(
288
- correlation_id=str(ctx.correlation_id)
289
- if ctx and ctx.correlation_id
290
- else "",
291
- agent_name=agent.name,
292
- run_id=ctx.task_id if ctx else "",
293
- output_type="log",
294
- content=f"\nAmount of output tokens: {stream_sequence}",
295
- sequence=stream_sequence,
296
- is_final=True,
297
- artifact_id=str(pre_generated_artifact_id),
298
- artifact_type=artifact_type_name,
299
- )
300
- # Fire-and-forget to avoid blocking DSPy's streaming loop
301
- task = asyncio.create_task(ws_broadcast(event))
302
- ws_broadcast_tasks.add(task)
303
- task.add_done_callback(ws_broadcast_tasks.discard)
304
-
305
- event = StreamingOutputEvent(
306
- correlation_id=str(ctx.correlation_id)
307
- if ctx and ctx.correlation_id
308
- else "",
309
- agent_name=agent.name,
310
- run_id=ctx.task_id if ctx else "",
311
- output_type="log",
312
- content="--- End of output ---",
313
- sequence=stream_sequence + 1,
314
- is_final=True,
315
- artifact_id=str(pre_generated_artifact_id),
316
- artifact_type=artifact_type_name,
317
- )
318
- # Fire-and-forget to avoid blocking DSPy's streaming loop
319
- task = asyncio.create_task(ws_broadcast(event))
320
- ws_broadcast_tasks.add(task)
321
- task.add_done_callback(ws_broadcast_tasks.discard)
322
- except Exception as e:
323
- logger.warning(f"Failed to emit final streaming event: {e}")
549
+ kind, text, signature_field, prediction = self._normalize_value(
550
+ value, dspy_mod
551
+ )
552
+
553
+ if kind == "status" and text:
554
+ await sink.on_status(text)
555
+ continue
556
+
557
+ if kind == "token" and text:
558
+ tokens_emitted += 1
559
+ await sink.on_token(text, signature_field)
560
+ continue
561
+
562
+ if kind == "prediction":
563
+ final_result = prediction
564
+ await sink.on_final(prediction, tokens_emitted)
565
+ await self._close_stream_generator(stream_generator)
566
+ break
567
+
568
+ await sink.flush()
324
569
 
325
570
  if final_result is None:
326
571
  raise RuntimeError(
@@ -328,7 +573,7 @@ class DSPyStreamingExecutor:
328
573
  )
329
574
 
330
575
  logger.info(
331
- f"Agent {agent.name}: WebSocket streaming completed ({stream_sequence} tokens)"
576
+ f"Agent {agent.name}: WebSocket streaming completed ({tokens_emitted} tokens)"
332
577
  )
333
578
  return final_result, None
334
579
 
@@ -345,406 +590,115 @@ class DSPyStreamingExecutor:
345
590
  pre_generated_artifact_id: Any = None,
346
591
  output_group=None,
347
592
  ) -> Any:
348
- """Execute DSPy program in streaming mode with Rich table updates.
349
-
350
- Args:
351
- dspy_mod: DSPy module
352
- program: DSPy program (Predict or ReAct)
353
- signature: DSPy Signature
354
- description: System description
355
- payload: Execution payload with semantic field names
356
- agent: Agent instance
357
- ctx: Execution context
358
- pre_generated_artifact_id: Pre-generated artifact ID for streaming
359
- output_group: OutputGroup defining expected outputs
593
+ """Execute DSPy program in streaming mode with Rich table updates."""
360
594
 
361
- Returns:
362
- Tuple of (DSPy Prediction result, stream display data for final rendering)
363
- """
364
595
  from rich.console import Console
365
- from rich.live import Live
366
596
 
367
597
  console = Console()
368
598
 
369
599
  # Get WebSocket broadcast function (security: wrapper prevents object traversal)
370
- # Phase 6+7 Security Fix: Use broadcast wrapper from Agent class variable (prevents GOD MODE restoration)
371
600
  from flock.core import Agent
372
601
 
373
602
  ws_broadcast = Agent._websocket_broadcast_global
374
603
 
375
- # Prepare stream listeners for output field
376
- listeners = []
377
- try:
378
- streaming_mod = getattr(dspy_mod, "streaming", None)
379
- if streaming_mod and hasattr(streaming_mod, "StreamListener"):
380
- for name, field in signature.output_fields.items():
381
- if field.annotation is str:
382
- listeners.append(
383
- streaming_mod.StreamListener(signature_field_name=name)
384
- )
385
- except Exception:
386
- listeners = []
387
-
604
+ listeners = self._make_listeners(dspy_mod, signature)
388
605
  streaming_task = dspy_mod.streamify(
389
606
  program,
390
607
  is_async_program=True,
391
608
  stream_listeners=listeners if listeners else None,
392
609
  )
393
610
 
394
- # Execute with appropriate payload format
395
- if isinstance(payload, dict) and "description" in payload:
396
- # Semantic fields: pass all fields as kwargs
397
- stream_generator = streaming_task(**payload)
398
- elif isinstance(payload, dict) and "input" in payload:
399
- # Legacy format: {"input": ..., "context": ...}
400
- stream_generator = streaming_task(
401
- description=description,
402
- input=payload["input"],
403
- context=payload.get("context", []),
404
- )
405
- else:
406
- # Old format: direct payload
407
- stream_generator = streaming_task(
408
- description=description, input=payload, context=[]
409
- )
611
+ stream_kwargs = self._payload_kwargs(payload=payload, description=description)
612
+ stream_generator = streaming_task(**stream_kwargs)
410
613
 
411
- signature_order = []
412
614
  status_field = self.status_output_field
413
615
  try:
414
616
  signature_order = list(signature.output_fields.keys())
415
617
  except Exception:
416
618
  signature_order = []
417
619
 
418
- # Initialize display data in full artifact format (matching OutputUtilityComponent display)
419
- display_data: OrderedDict[str, Any] = OrderedDict()
420
-
421
- # Use the pre-generated artifact ID that was created before execution started
422
- display_data["id"] = str(pre_generated_artifact_id)
423
-
424
- # Get the artifact type name from agent configuration
425
- artifact_type_name = "output"
426
- # Use output_group.outputs (current group) if available, otherwise fallback to agent.outputs (all groups)
427
- outputs_to_display = (
428
- output_group.outputs
429
- if output_group and hasattr(output_group, "outputs")
430
- else agent.outputs
431
- if hasattr(agent, "outputs")
432
- else []
433
- )
434
-
435
- if outputs_to_display:
436
- artifact_type_name = outputs_to_display[0].spec.type_name
437
- for output in outputs_to_display:
438
- if output.spec.type_name not in artifact_type_name:
439
- artifact_type_name += ", " + output.spec.type_name
440
-
441
- display_data["type"] = artifact_type_name
442
- display_data["payload"] = OrderedDict()
443
-
444
- # Add output fields to payload section
445
- for field_name in signature_order:
446
- if field_name != "description": # Skip description field
447
- display_data["payload"][field_name] = ""
448
-
449
- display_data["produced_by"] = agent.name
450
- display_data["correlation_id"] = (
451
- str(ctx.correlation_id) if ctx and ctx.correlation_id else None
620
+ display_data, artifact_type_name = self._initialize_display_data(
621
+ signature_order=signature_order,
622
+ agent=agent,
623
+ ctx=ctx,
624
+ pre_generated_artifact_id=pre_generated_artifact_id,
625
+ output_group=output_group,
626
+ status_field=status_field,
452
627
  )
453
- display_data["partition_key"] = None
454
- display_data["tags"] = "set()"
455
- display_data["visibility"] = OrderedDict([("kind", "Public")])
456
- display_data["created_at"] = "streaming..."
457
- display_data["version"] = 1
458
- display_data["status"] = status_field
459
628
 
460
629
  stream_buffers: defaultdict[str, list[str]] = defaultdict(list)
461
- stream_buffers[status_field] = []
462
- stream_sequence = 0 # Monotonic sequence for ordering
463
-
464
- # Track background WebSocket broadcast tasks to prevent garbage collection
465
- ws_broadcast_tasks: set[asyncio.Task] = set()
466
-
467
- formatter = theme_dict = styles = agent_label = None
468
- live_cm = nullcontext()
469
630
  overflow_mode = self.stream_vertical_overflow
470
631
 
471
632
  if not self.no_output:
472
- # Import the patch function here to ensure it's applied
473
- from flock.engines.dspy_engine import _ensure_live_crop_above
474
-
475
- _ensure_live_crop_above()
476
633
  (
477
634
  formatter,
478
635
  theme_dict,
479
636
  styles,
480
637
  agent_label,
481
- ) = self.prepare_stream_formatter(agent)
482
- initial_panel = formatter.format_result(
483
- display_data, agent_label, theme_dict, styles
484
- )
485
- live_cm = Live(
486
- initial_panel,
638
+ live_cm,
639
+ ) = self._prepare_rich_env(
487
640
  console=console,
488
- refresh_per_second=4,
489
- transient=False,
490
- vertical_overflow=overflow_mode,
641
+ display_data=display_data,
642
+ agent=agent,
643
+ overflow_mode=overflow_mode,
491
644
  )
645
+ else:
646
+ formatter = theme_dict = styles = agent_label = None
647
+ live_cm = nullcontext()
648
+
649
+ timestamp_factory = lambda: datetime.now(UTC).isoformat()
492
650
 
493
651
  final_result: Any = None
652
+ tokens_emitted = 0
653
+ sinks: list[StreamSink] = []
654
+ rich_sink: RichSink | None = None
494
655
 
495
656
  with live_cm as live:
657
+ rich_sink = self._build_rich_sink(
658
+ live=live,
659
+ formatter=formatter,
660
+ display_data=display_data,
661
+ agent_label=agent_label,
662
+ theme_dict=theme_dict,
663
+ styles=styles,
664
+ status_field=status_field,
665
+ signature_order=signature_order,
666
+ stream_buffers=stream_buffers,
667
+ timestamp_factory=timestamp_factory,
668
+ )
496
669
 
497
- def _refresh_panel() -> None:
498
- if formatter is None or live is None:
499
- return
500
- live.update(
501
- formatter.format_result(
502
- display_data, agent_label, theme_dict, styles
503
- )
504
- )
670
+ ws_sink = self._build_websocket_sink(
671
+ ws_broadcast=ws_broadcast,
672
+ ctx=ctx,
673
+ agent=agent,
674
+ pre_generated_artifact_id=pre_generated_artifact_id,
675
+ artifact_type_name=artifact_type_name,
676
+ )
677
+
678
+ sinks = self._collect_sinks(rich_sink=rich_sink, ws_sink=ws_sink)
679
+ final_result, tokens_emitted = await self._consume_stream(
680
+ stream_generator, sinks, dspy_mod
681
+ )
505
682
 
506
- async for value in stream_generator:
507
- try:
508
- from dspy.streaming import StatusMessage, StreamResponse
509
- from litellm import ModelResponseStream
510
- except Exception:
511
- StatusMessage = object # type: ignore
512
- StreamResponse = object # type: ignore
513
- ModelResponseStream = object # type: ignore
514
-
515
- if isinstance(value, StatusMessage):
516
- token = getattr(value, "message", "")
517
- if token:
518
- stream_buffers[status_field].append(str(token) + "\n")
519
- display_data["status"] = "".join(stream_buffers[status_field])
520
-
521
- # Emit to WebSocket (non-blocking to prevent deadlock)
522
- if ws_broadcast and token:
523
- try:
524
- event = StreamingOutputEvent(
525
- correlation_id=str(ctx.correlation_id)
526
- if ctx and ctx.correlation_id
527
- else "",
528
- agent_name=agent.name,
529
- run_id=ctx.task_id if ctx else "",
530
- output_type="llm_token",
531
- content=str(token + "\n"),
532
- sequence=stream_sequence,
533
- is_final=False,
534
- artifact_id=str(
535
- pre_generated_artifact_id
536
- ), # Phase 6: Track artifact for message streaming
537
- artifact_type=artifact_type_name, # Phase 6: Artifact type name
538
- )
539
- # Use create_task to avoid blocking the streaming loop
540
- task = asyncio.create_task(ws_broadcast(event))
541
- ws_broadcast_tasks.add(task)
542
- task.add_done_callback(ws_broadcast_tasks.discard)
543
- stream_sequence += 1
544
- except Exception as e:
545
- logger.warning(f"Failed to emit streaming event: {e}")
546
- else:
547
- logger.debug(
548
- "No WebSocket manager present for streaming event."
549
- )
550
-
551
- if formatter is not None:
552
- _refresh_panel()
553
- continue
554
-
555
- if isinstance(value, StreamResponse):
556
- token = getattr(value, "chunk", None)
557
- signature_field = getattr(value, "signature_field_name", None)
558
- if signature_field and signature_field != "description":
559
- # Update payload section - accumulate in "output" buffer
560
- buffer_key = f"_stream_{signature_field}"
561
- if token:
562
- stream_buffers[buffer_key].append(str(token))
563
- # Show streaming text in payload
564
- display_data["payload"]["_streaming"] = "".join(
565
- stream_buffers[buffer_key]
566
- )
567
-
568
- # Emit to WebSocket (non-blocking to prevent deadlock)
569
- if ws_broadcast:
570
- logger.info(
571
- f"[STREAMING] Emitting StreamResponse token='{token}', sequence={stream_sequence}"
572
- )
573
- try:
574
- event = StreamingOutputEvent(
575
- correlation_id=str(ctx.correlation_id)
576
- if ctx and ctx.correlation_id
577
- else "",
578
- agent_name=agent.name,
579
- run_id=ctx.task_id if ctx else "",
580
- output_type="llm_token",
581
- content=str(token),
582
- sequence=stream_sequence,
583
- is_final=False,
584
- artifact_id=str(
585
- pre_generated_artifact_id
586
- ), # Phase 6: Track artifact for message streaming
587
- artifact_type=artifact_type_name, # Phase 6: Artifact type name
588
- )
589
- # Use create_task to avoid blocking the streaming loop
590
- task = asyncio.create_task(ws_broadcast(event))
591
- ws_broadcast_tasks.add(task)
592
- task.add_done_callback(ws_broadcast_tasks.discard)
593
- stream_sequence += 1
594
- except Exception as e:
595
- logger.warning(
596
- f"Failed to emit streaming event: {e}"
597
- )
598
-
599
- if formatter is not None:
600
- _refresh_panel()
601
- continue
602
-
603
- if isinstance(value, ModelResponseStream):
604
- chunk = value
605
- token = chunk.choices[0].delta.content or ""
606
- signature_field = getattr(value, "signature_field_name", None)
607
-
608
- if signature_field and signature_field != "description":
609
- # Update payload section - accumulate in buffer
610
- buffer_key = f"_stream_{signature_field}"
611
- if token:
612
- stream_buffers[buffer_key].append(str(token))
613
- # Show streaming text in payload
614
- display_data["payload"]["_streaming"] = "".join(
615
- stream_buffers[buffer_key]
616
- )
617
- elif token:
618
- stream_buffers[status_field].append(str(token))
619
- display_data["status"] = "".join(stream_buffers[status_field])
620
-
621
- # Emit to WebSocket (non-blocking to prevent deadlock)
622
- if ws_broadcast and token:
623
- try:
624
- event = StreamingOutputEvent(
625
- correlation_id=str(ctx.correlation_id)
626
- if ctx and ctx.correlation_id
627
- else "",
628
- agent_name=agent.name,
629
- run_id=ctx.task_id if ctx else "",
630
- output_type="llm_token",
631
- content=str(token),
632
- sequence=stream_sequence,
633
- is_final=False,
634
- artifact_id=str(
635
- pre_generated_artifact_id
636
- ), # Phase 6: Track artifact for message streaming
637
- artifact_type=display_data[
638
- "type"
639
- ], # Phase 6: Artifact type name from display_data
640
- )
641
- # Use create_task to avoid blocking the streaming loop
642
- task = asyncio.create_task(ws_broadcast(event))
643
- ws_broadcast_tasks.add(task)
644
- task.add_done_callback(ws_broadcast_tasks.discard)
645
- stream_sequence += 1
646
- except Exception as e:
647
- logger.warning(f"Failed to emit streaming event: {e}")
648
-
649
- if formatter is not None:
650
- _refresh_panel()
651
- continue
652
-
653
- if isinstance(value, dspy_mod.Prediction):
654
- final_result = value
655
-
656
- # Emit final streaming event (non-blocking to prevent deadlock)
657
- if ws_broadcast:
658
- try:
659
- event = StreamingOutputEvent(
660
- correlation_id=str(ctx.correlation_id)
661
- if ctx and ctx.correlation_id
662
- else "",
663
- agent_name=agent.name,
664
- run_id=ctx.task_id if ctx else "",
665
- output_type="log",
666
- content="\nAmount of output tokens: "
667
- + str(stream_sequence),
668
- sequence=stream_sequence,
669
- is_final=True, # Mark as final
670
- artifact_id=str(
671
- pre_generated_artifact_id
672
- ), # Phase 6: Track artifact for message streaming
673
- artifact_type=display_data[
674
- "type"
675
- ], # Phase 6: Artifact type name
676
- )
677
- # Use create_task to avoid blocking the streaming loop
678
- task = asyncio.create_task(ws_broadcast(event))
679
- ws_broadcast_tasks.add(task)
680
- task.add_done_callback(ws_broadcast_tasks.discard)
681
- event = StreamingOutputEvent(
682
- correlation_id=str(ctx.correlation_id)
683
- if ctx and ctx.correlation_id
684
- else "",
685
- agent_name=agent.name,
686
- run_id=ctx.task_id if ctx else "",
687
- output_type="log",
688
- content="--- End of output ---",
689
- sequence=stream_sequence,
690
- is_final=True, # Mark as final
691
- artifact_id=str(
692
- pre_generated_artifact_id
693
- ), # Phase 6: Track artifact for message streaming
694
- artifact_type=display_data[
695
- "type"
696
- ], # Phase 6: Artifact type name
697
- )
698
- # Use create_task to avoid blocking the streaming loop
699
- task = asyncio.create_task(ws_broadcast(event))
700
- ws_broadcast_tasks.add(task)
701
- task.add_done_callback(ws_broadcast_tasks.discard)
702
- except Exception as e:
703
- logger.warning(f"Failed to emit final streaming event: {e}")
704
-
705
- if formatter is not None:
706
- # Update payload section with final values
707
- payload_data = OrderedDict()
708
- for field_name in signature_order:
709
- if field_name != "description" and hasattr(
710
- final_result, field_name
711
- ):
712
- field_value = getattr(final_result, field_name)
713
-
714
- # Convert BaseModel instances to dicts for proper table rendering
715
- if isinstance(field_value, list):
716
- # Handle lists of BaseModel instances (fan-out/batch)
717
- payload_data[field_name] = [
718
- item.model_dump()
719
- if isinstance(item, BaseModel)
720
- else item
721
- for item in field_value
722
- ]
723
- elif isinstance(field_value, BaseModel):
724
- # Handle single BaseModel instance
725
- payload_data[field_name] = field_value.model_dump()
726
- else:
727
- # Handle primitive types
728
- payload_data[field_name] = field_value
729
-
730
- # Update all fields with actual values
731
- display_data["payload"].clear()
732
- display_data["payload"].update(payload_data)
733
-
734
- # Update timestamp
735
- from datetime import datetime
736
-
737
- display_data["created_at"] = datetime.now(UTC).isoformat()
738
-
739
- # Remove status field from display
740
- display_data.pop("status", None)
741
- _refresh_panel()
683
+ await self._flush_sinks(sinks)
742
684
 
743
685
  if final_result is None:
744
686
  raise RuntimeError("Streaming did not yield a final prediction.")
745
687
 
746
- # Return both the result and the display data for final ID update
747
- return final_result, (formatter, display_data, theme_dict, styles, agent_label)
688
+ stream_display = self._finalize_stream_display(
689
+ rich_sink=rich_sink,
690
+ formatter=formatter,
691
+ display_data=display_data,
692
+ theme_dict=theme_dict,
693
+ styles=styles,
694
+ agent_label=agent_label,
695
+ )
696
+
697
+ logger.info(
698
+ f"Agent {agent.name}: Rich streaming completed ({tokens_emitted} tokens)"
699
+ )
700
+
701
+ return final_result, stream_display
748
702
 
749
703
  def prepare_stream_formatter(
750
704
  self, agent: Any