swarmkit 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,487 @@
1
+ """Pipeline - Fluent API for Swarm Operations.
2
+
3
+ Thin wrapper over Swarm providing method chaining, timing, and events.
4
+
5
+ Example:
6
+ ```python
7
+ pipeline = (
8
+ Pipeline(swarm)
9
+ .map(MapConfig(prompt="Analyze..."))
10
+ .filter(FilterConfig(
11
+ prompt="Rate quality",
12
+ schema=QualitySchema,
13
+ condition=lambda d: d.score > 7,
14
+ ))
15
+ .reduce(ReduceConfig(prompt="Synthesize..."))
16
+ )
17
+
18
+ # Run with items
19
+ result = await pipeline.run(documents)
20
+
21
+ # Reusable - run with different data
22
+ await pipeline.run(batch1)
23
+ await pipeline.run(batch2)
24
+ ```
25
+ """
26
+
27
+ import secrets
28
+ import time
29
+ from dataclasses import dataclass, field, replace
30
+ from typing import Any, Callable, Generic, List, Optional, TypeVar, Union, overload
31
+
32
+ from ..swarm import Swarm
33
+ from ..swarm.types import FileMap, ItemInput, BestOfConfig, VerifyConfig
34
+ from ..swarm.results import SwarmResult, SwarmResultList, ReduceResult
35
+ from ..retry import RetryConfig
36
+ from .types import (
37
+ Step,
38
+ StepType,
39
+ MapConfig,
40
+ FilterConfig,
41
+ ReduceConfig,
42
+ StepResult,
43
+ PipelineResult,
44
+ PipelineEvents,
45
+ PipelineEventMap,
46
+ StepStartEvent,
47
+ StepCompleteEvent,
48
+ StepErrorEvent,
49
+ ItemRetryEvent,
50
+ WorkerCompleteEvent,
51
+ VerifierCompleteEvent,
52
+ CandidateCompleteEvent,
53
+ JudgeCompleteEvent,
54
+ )
55
+
56
+
57
# Type parameter shared by Pipeline, its step configs (MapConfig[T], ...)
# and PipelineResult[T].
T = TypeVar("T")
58
+
59
+
60
+ # =============================================================================
61
+ # PIPELINE
62
+ # =============================================================================
63
+
64
@dataclass
class Pipeline(Generic[T]):
    """Pipeline for chaining Swarm operations.

    Swarm is bound at construction (infrastructure).
    Items are passed at execution (data).
    Pipeline is treated as immutable - every builder method (``map``,
    ``filter``, ``reduce``, ``on``) returns a new instance and leaves
    ``self`` untouched, so a pipeline can be extended or re-run freely.
    """
    _swarm: Swarm
    _steps: List[Step] = field(default_factory=list)
    _events: PipelineEvents = field(default_factory=PipelineEvents)

    # PipelineEvents handler attributes merged by the object-style ``on()``.
    # Deliberately unannotated so @dataclass does not turn it into a field.
    _EVENT_HANDLER_NAMES = (
        "on_step_start",
        "on_step_complete",
        "on_step_error",
        "on_item_retry",
        "on_worker_complete",
        "on_verifier_complete",
        "on_candidate_complete",
        "on_judge_complete",
    )

    # ===========================================================================
    # STEP METHODS
    # ===========================================================================

    def _with_step(self, step_type: StepType, config: Any) -> List[Step]:
        """Return a new step list: the current steps plus one appended step."""
        # Build a fresh list so the original pipeline's steps are never mutated.
        return self._steps + [Step(type=step_type, config=config)]

    def map(self, config: MapConfig[T]) -> 'Pipeline[T]':
        """Add a map step to transform items in parallel.

        Only the successful results of this step are passed to the next step.
        """
        return Pipeline(
            _swarm=self._swarm,
            _steps=self._with_step("map", config),
            _events=self._events,
        )

    def filter(self, config: FilterConfig[T]) -> 'Pipeline[T]':
        """Add a filter step to evaluate and filter items.

        Which items flow to the next step is selected by ``config.emit``
        ("success", "filtered" or "all"; defaults to "success").
        """
        return Pipeline(
            _swarm=self._swarm,
            _steps=self._with_step("filter", config),
            _events=self._events,
        )

    def reduce(self, config: ReduceConfig[T]) -> 'TerminalPipeline[T]':
        """Add a reduce step (terminal - no steps can follow).

        Returns a TerminalPipeline, which rejects any further step methods.
        """
        return TerminalPipeline(
            _swarm=self._swarm,
            _steps=self._with_step("reduce", config),
            _events=self._events,
        )

    # ===========================================================================
    # EVENTS
    # ===========================================================================

    @overload
    def on(self, handlers: PipelineEvents) -> 'Pipeline[T]': ...

    @overload
    def on(self, event: str, handler: Callable) -> 'Pipeline[T]': ...

    def on(self, event_or_handlers: Union[PipelineEvents, str], handler: Optional[Callable] = None) -> 'Pipeline[T]':
        """Register event handlers for step lifecycle.

        Supports two styles:
        - Object: .on(PipelineEvents(on_step_complete=fn))
          Handlers set on the given object override existing ones; unset
          (falsy) handlers keep the currently registered handler.
        - Chainable: .on("step_complete", fn)
          Sets exactly one handler (``handler`` is stored as given).

        Raises:
            ValueError: if a string event name is not in PipelineEventMap.
        """
        if isinstance(event_or_handlers, str):
            # Chainable style: translate the public event name to the
            # PipelineEvents attribute it controls.
            key = PipelineEventMap.get(event_or_handlers)
            if key is None:
                raise ValueError(f"Unknown event: {event_or_handlers}")
            new_events = replace(self._events, **{key: handler})
        else:
            # Object style: new handler wins, otherwise keep the existing one.
            merged = {
                name: getattr(event_or_handlers, name) or getattr(self._events, name)
                for name in self._EVENT_HANDLER_NAMES
            }
            new_events = replace(self._events, **merged)
        return Pipeline(
            _swarm=self._swarm,
            _steps=self._steps,
            _events=new_events,
        )

    # ===========================================================================
    # EXECUTION
    # ===========================================================================

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Whole milliseconds elapsed since ``start`` (a perf_counter reading)."""
        return int((time.perf_counter() - start) * 1000)

    async def run(self, items: List[ItemInput]) -> PipelineResult[T]:
        """Execute the pipeline with the given items.

        Steps run sequentially; each step's selected items feed the next one.
        A reduce step ends the run and its result becomes the pipeline output.
        Otherwise the output is the full result list of the last step
        (``[]`` when the pipeline has no steps).

        Raises:
            Exception: any exception raised while executing a step is
                reported to ``on_step_error`` (if registered) and re-raised.
        """
        run_id = secrets.token_hex(8)
        step_results: List[StepResult] = []
        current_items: List[ItemInput] = list(items)
        # perf_counter is monotonic: durations are immune to wall-clock changes.
        run_start = time.perf_counter()

        for i, step in enumerate(self._steps):
            step_name = getattr(step.config, 'name', None)
            step_start = time.perf_counter()

            if self._events.on_step_start:
                self._events.on_step_start(StepStartEvent(
                    type=step.type,
                    index=i,
                    name=step_name,
                    item_count=len(current_items),
                ))

            # Only the step itself is guarded: an exception raised by a user's
            # on_step_complete handler must not be reported as a step error.
            try:
                result = await self._execute_step(step, current_items, i, step_name)
            except Exception as e:
                if self._events.on_step_error:
                    self._events.on_step_error(StepErrorEvent(
                        type=step.type,
                        index=i,
                        name=step_name,
                        error=e,
                    ))
                raise

            duration_ms = self._elapsed_ms(step_start)
            step_results.append(StepResult(
                type=step.type,
                index=i,
                duration_ms=duration_ms,
                results=result["output"],
            ))

            if self._events.on_step_complete:
                self._events.on_step_complete(StepCompleteEvent(
                    type=step.type,
                    index=i,
                    name=step_name,
                    duration_ms=duration_ms,
                    success_count=result["success_count"],
                    error_count=result["error_count"],
                    filtered_count=result["filtered_count"],
                ))

            # Reduce is terminal
            if step.type == "reduce":
                return PipelineResult(
                    run_id=run_id,
                    steps=step_results,
                    output=result["output"],
                    total_duration_ms=self._elapsed_ms(run_start),
                )

            current_items = result["next_items"]

        last_result = step_results[-1] if step_results else None
        return PipelineResult(
            run_id=run_id,
            steps=step_results,
            output=last_result.results if last_result else [],
            total_duration_ms=self._elapsed_ms(run_start),
        )

    # ===========================================================================
    # PRIVATE
    # ===========================================================================

    async def _execute_step(
        self,
        step: Step,
        items: List[ItemInput],
        step_index: int,
        step_name: Optional[str],
    ) -> dict:
        """Execute a single step on ``items`` and summarize its results.

        Returns:
            dict with keys ``output`` (what is stored in StepResult),
            ``next_items`` (input for the next step), ``success_count``,
            ``error_count`` and ``filtered_count``.

        Raises:
            ValueError: for an unknown step type or filter ``emit`` mode.
        """
        config = step.config

        if step.type == "map":
            results = await self._swarm.map(
                items=items,
                prompt=config.prompt,
                system_prompt=config.system_prompt,
                schema=config.schema,
                schema_options=config.schema_options,
                agent=config.agent,
                mcp_servers=config.mcp_servers,
                best_of=self._wrap_best_of(config.best_of, step_index, step_name),
                verify=self._wrap_verify(config.verify, step_index, step_name),
                retry=self._wrap_retry(config.retry, step_index, step_name),
                timeout_ms=config.timeout_ms,
            )
            return {
                "output": list(results),
                "next_items": results.success,
                "success_count": len(results.success),
                "error_count": len(results.error),
                "filtered_count": 0,
            }

        if step.type == "filter":
            results = await self._swarm.filter(
                items=items,
                prompt=config.prompt,
                schema=config.schema,
                condition=config.condition,
                schema_options=config.schema_options,
                system_prompt=config.system_prompt,
                agent=config.agent,
                mcp_servers=config.mcp_servers,
                verify=self._wrap_verify(config.verify, step_index, step_name),
                retry=self._wrap_retry(config.retry, step_index, step_name),
                timeout_ms=config.timeout_ms,
            )
            # Missing/None emit falls back to the documented default "success";
            # an unrecognized value is a configuration error, not "all".
            emit = getattr(config, 'emit', None) or 'success'
            if emit == "success":
                next_items = results.success
            elif emit == "filtered":
                next_items = results.filtered
            elif emit == "all":
                next_items = results.success + results.filtered
            else:
                raise ValueError(
                    f"Unknown filter emit mode: {emit!r} "
                    "(expected 'success', 'filtered' or 'all')"
                )
            return {
                "output": list(results),
                "next_items": next_items,
                "success_count": len(results.success),
                "error_count": len(results.error),
                "filtered_count": len(results.filtered),
            }

        if step.type == "reduce":
            result = await self._swarm.reduce(
                items=items,
                prompt=config.prompt,
                system_prompt=config.system_prompt,
                schema=config.schema,
                schema_options=config.schema_options,
                agent=config.agent,
                mcp_servers=config.mcp_servers,
                verify=self._wrap_verify(config.verify, step_index, step_name),
                retry=self._wrap_retry(config.retry, step_index, step_name),
                timeout_ms=config.timeout_ms,
            )
            return {
                "output": result,
                "next_items": [],
                "success_count": 1 if result.status == "success" else 0,
                "error_count": 1 if result.status == "error" else 0,
                "filtered_count": 0,
            }

        # Previously any unknown type silently ran as a reduce step.
        raise ValueError(f"Unknown step type: {step.type!r}")

    def _wrap_retry(
        self,
        config: Optional[RetryConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[RetryConfig]:
        """Wrap retry config to inject pipeline-level callback.

        The returned config calls the user's own ``on_item_retry`` first, then
        emits an ItemRetryEvent to the pipeline's ``on_item_retry`` handler.
        Returns None when ``config`` is None.
        """
        if config is None:
            return None

        original_callback = config.on_item_retry

        def wrapped_callback(item_index: int, attempt: int, error: str):
            if original_callback:
                original_callback(item_index, attempt, error)
            if self._events.on_item_retry:
                self._events.on_item_retry(ItemRetryEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    error=error,
                ))

        # NOTE(review): only the fields listed here are copied; confirm this
        # matches RetryConfig's full definition.
        return RetryConfig(
            max_attempts=config.max_attempts,
            backoff_ms=config.backoff_ms,
            backoff_multiplier=config.backoff_multiplier,
            retry_on=config.retry_on,
            on_item_retry=wrapped_callback,
        )

    def _wrap_verify(
        self,
        config: Optional[VerifyConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[VerifyConfig]:
        """Wrap verify config to inject pipeline-level callbacks.

        Worker/verifier callbacks from ``config`` run first, then the
        pipeline's ``on_worker_complete`` / ``on_verifier_complete`` handlers
        receive step-tagged events. Returns None when ``config`` is None.
        """
        if config is None:
            return None

        original_worker = config.on_worker_complete
        original_verifier = config.on_verifier_complete

        def wrapped_worker(item_index: int, attempt: int, status: str):
            if original_worker:
                original_worker(item_index, attempt, status)
            if self._events.on_worker_complete:
                self._events.on_worker_complete(WorkerCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    status=status,
                ))

        def wrapped_verifier(item_index: int, attempt: int, passed: bool, feedback: Optional[str]):
            if original_verifier:
                original_verifier(item_index, attempt, passed, feedback)
            if self._events.on_verifier_complete:
                self._events.on_verifier_complete(VerifierCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    passed=passed,
                    feedback=feedback,
                ))

        # NOTE(review): only the fields listed here are copied; confirm this
        # matches VerifyConfig's full definition.
        return VerifyConfig(
            criteria=config.criteria,
            max_attempts=config.max_attempts,
            verifier_agent=config.verifier_agent,
            verifier_mcp_servers=config.verifier_mcp_servers,
            on_worker_complete=wrapped_worker,
            on_verifier_complete=wrapped_verifier,
        )

    def _wrap_best_of(
        self,
        config: Optional[BestOfConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[BestOfConfig]:
        """Wrap bestOf config to inject pipeline-level callbacks.

        Candidate/judge callbacks from ``config`` run first, then the
        pipeline's ``on_candidate_complete`` / ``on_judge_complete`` handlers
        receive step-tagged events. Returns None when ``config`` is None.
        """
        if config is None:
            return None

        original_candidate = config.on_candidate_complete
        original_judge = config.on_judge_complete

        def wrapped_candidate(item_index: int, candidate_index: int, status: str):
            if original_candidate:
                original_candidate(item_index, candidate_index, status)
            if self._events.on_candidate_complete:
                self._events.on_candidate_complete(CandidateCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    candidate_index=candidate_index,
                    status=status,
                ))

        def wrapped_judge(item_index: int, winner_index: int, reasoning: str):
            if original_judge:
                original_judge(item_index, winner_index, reasoning)
            if self._events.on_judge_complete:
                self._events.on_judge_complete(JudgeCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    winner_index=winner_index,
                    reasoning=reasoning,
                ))

        # NOTE(review): only the fields listed here are copied; confirm this
        # matches BestOfConfig's full definition.
        return BestOfConfig(
            judge_criteria=config.judge_criteria,
            n=config.n,
            task_agents=config.task_agents,
            judge_agent=config.judge_agent,
            mcp_servers=config.mcp_servers,
            judge_mcp_servers=config.judge_mcp_servers,
            on_candidate_complete=wrapped_candidate,
            on_judge_complete=wrapped_judge,
        )
+
437
+
438
+ # =============================================================================
439
+ # TERMINAL PIPELINE
440
+ # =============================================================================
441
+
442
@dataclass
class TerminalPipeline(Pipeline[T]):
    """Pipeline after reduce - no more steps can be added.

    ``map``, ``filter`` and ``reduce`` raise RuntimeError; ``on`` and ``run``
    behave as on Pipeline, and ``on`` keeps the pipeline terminal.
    """

    def map(self, config: MapConfig) -> 'Pipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    def filter(self, config: FilterConfig) -> 'Pipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    def reduce(self, config: ReduceConfig) -> 'TerminalPipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    @overload
    def on(self, handlers: PipelineEvents) -> 'TerminalPipeline[T]': ...

    @overload
    def on(self, event: str, handler: Callable) -> 'TerminalPipeline[T]': ...

    def on(self, event_or_handlers: Union[PipelineEvents, str], handler: Optional[Callable] = None) -> 'TerminalPipeline[T]':
        """Register event handlers for step lifecycle.

        Same semantics as ``Pipeline.on`` (which also raises ValueError for an
        unknown event name), but the result stays a TerminalPipeline.
        """
        # Delegate parsing/merging to the base class so both implementations
        # cannot drift apart, then re-wrap to preserve terminal-ness.
        updated = super().on(event_or_handlers, handler)
        return TerminalPipeline(
            _swarm=updated._swarm,
            _steps=updated._steps,
            _events=updated._events,
        )
+ )