swarmkit 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bridge/__init__.py +5 -0
- bridge/dist/bridge.bundle.cjs +8 -0
- swarmkit/__init__.py +152 -0
- swarmkit/agent.py +480 -0
- swarmkit/bridge.py +475 -0
- swarmkit/config.py +92 -0
- swarmkit/pipeline/__init__.py +59 -0
- swarmkit/pipeline/pipeline.py +487 -0
- swarmkit/pipeline/types.py +272 -0
- swarmkit/prompts/__init__.py +126 -0
- swarmkit/prompts/agent_md/judge.md +30 -0
- swarmkit/prompts/agent_md/reduce.md +7 -0
- swarmkit/prompts/agent_md/verify.md +33 -0
- swarmkit/prompts/user/judge.md +1 -0
- swarmkit/prompts/user/retry_feedback.md +9 -0
- swarmkit/prompts/user/verify.md +1 -0
- swarmkit/results.py +45 -0
- swarmkit/retry.py +133 -0
- swarmkit/schema.py +107 -0
- swarmkit/swarm/__init__.py +75 -0
- swarmkit/swarm/results.py +140 -0
- swarmkit/swarm/swarm.py +1751 -0
- swarmkit/swarm/types.py +193 -0
- swarmkit/utils.py +82 -0
- swarmkit-0.1.34.dist-info/METADATA +80 -0
- swarmkit-0.1.34.dist-info/RECORD +29 -0
- swarmkit-0.1.34.dist-info/WHEEL +5 -0
- swarmkit-0.1.34.dist-info/licenses/LICENSE +24 -0
- swarmkit-0.1.34.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
"""Pipeline - Fluent API for Swarm Operations.
|
|
2
|
+
|
|
3
|
+
Thin wrapper over Swarm providing method chaining, timing, and events.
|
|
4
|
+
|
|
5
|
+
Example:
|
|
6
|
+
```python
|
|
7
|
+
pipeline = (
|
|
8
|
+
Pipeline(swarm)
|
|
9
|
+
.map(MapConfig(prompt="Analyze..."))
|
|
10
|
+
.filter(FilterConfig(
|
|
11
|
+
prompt="Rate quality",
|
|
12
|
+
schema=QualitySchema,
|
|
13
|
+
condition=lambda d: d.score > 7,
|
|
14
|
+
))
|
|
15
|
+
.reduce(ReduceConfig(prompt="Synthesize..."))
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Run with items
|
|
19
|
+
result = await pipeline.run(documents)
|
|
20
|
+
|
|
21
|
+
# Reusable - run with different data
|
|
22
|
+
await pipeline.run(batch1)
|
|
23
|
+
await pipeline.run(batch2)
|
|
24
|
+
```
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import secrets
|
|
28
|
+
import time
|
|
29
|
+
from dataclasses import dataclass, field, replace
|
|
30
|
+
from typing import Any, Callable, Generic, List, Optional, TypeVar, Union, overload
|
|
31
|
+
|
|
32
|
+
from ..swarm import Swarm
|
|
33
|
+
from ..swarm.types import FileMap, ItemInput, BestOfConfig, VerifyConfig
|
|
34
|
+
from ..swarm.results import SwarmResult, SwarmResultList, ReduceResult
|
|
35
|
+
from ..retry import RetryConfig
|
|
36
|
+
from .types import (
|
|
37
|
+
Step,
|
|
38
|
+
StepType,
|
|
39
|
+
MapConfig,
|
|
40
|
+
FilterConfig,
|
|
41
|
+
ReduceConfig,
|
|
42
|
+
StepResult,
|
|
43
|
+
PipelineResult,
|
|
44
|
+
PipelineEvents,
|
|
45
|
+
PipelineEventMap,
|
|
46
|
+
StepStartEvent,
|
|
47
|
+
StepCompleteEvent,
|
|
48
|
+
StepErrorEvent,
|
|
49
|
+
ItemRetryEvent,
|
|
50
|
+
WorkerCompleteEvent,
|
|
51
|
+
VerifierCompleteEvent,
|
|
52
|
+
CandidateCompleteEvent,
|
|
53
|
+
JudgeCompleteEvent,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
T = TypeVar('T')  # payload type carried by pipeline items/results (parameterizes configs and results)


# =============================================================================
# PIPELINE
# =============================================================================
|
|
63
|
+
|
|
64
|
+
@dataclass
class Pipeline(Generic[T]):
    """Pipeline for chaining Swarm operations.

    Swarm is bound at construction (infrastructure).
    Items are passed at execution (data).
    Pipeline is immutable - each method returns a new instance.
    """
    # Executor every step runs on; shared (never copied) across derived pipelines.
    _swarm: Swarm
    # Ordered step definitions; step methods build a new list rather than mutate.
    _steps: List[Step] = field(default_factory=list)
    # Pipeline-level lifecycle handlers; merged immutably via .on().
    _events: PipelineEvents = field(default_factory=PipelineEvents)

    # ===========================================================================
    # STEP METHODS
    # ===========================================================================

    def map(self, config: MapConfig[T]) -> 'Pipeline[T]':
        """Add a map step to transform items in parallel.

        Returns a new Pipeline; self is left unchanged (immutability).
        """
        new_steps = self._steps + [Step(type="map", config=config)]
        return Pipeline(
            _swarm=self._swarm,
            _steps=new_steps,
            _events=self._events,
        )

    def filter(self, config: FilterConfig[T]) -> 'Pipeline[T]':
        """Add a filter step to evaluate and filter items.

        Returns a new Pipeline; self is left unchanged (immutability).
        """
        new_steps = self._steps + [Step(type="filter", config=config)]
        return Pipeline(
            _swarm=self._swarm,
            _steps=new_steps,
            _events=self._events,
        )

    def reduce(self, config: ReduceConfig[T]) -> 'TerminalPipeline[T]':
        """Add a reduce step (terminal - no steps can follow).

        Returns a TerminalPipeline, whose step methods raise RuntimeError.
        """
        new_steps = self._steps + [Step(type="reduce", config=config)]
        return TerminalPipeline(
            _swarm=self._swarm,
            _steps=new_steps,
            _events=self._events,
        )

    # ===========================================================================
    # EVENTS
    # ===========================================================================

    @overload
    def on(self, handlers: PipelineEvents) -> 'Pipeline[T]': ...

    @overload
    def on(self, event: str, handler: Callable) -> 'Pipeline[T]': ...

    def on(self, event_or_handlers: Union[PipelineEvents, str], handler: Optional[Callable] = None) -> 'Pipeline[T]':
        """Register event handlers for step lifecycle.

        Supports two styles:
        - Object: .on(PipelineEvents(on_step_complete=fn))
        - Chainable: .on("step_complete", fn)

        Returns a new Pipeline carrying the merged handlers.

        Raises:
            ValueError: if a string event name is not in PipelineEventMap.
        """
        if isinstance(event_or_handlers, str):
            # Chainable style: .on("step_complete", fn)
            # PipelineEventMap translates the public event name to the
            # PipelineEvents field name.
            key = PipelineEventMap.get(event_or_handlers)
            if key is None:
                raise ValueError(f"Unknown event: {event_or_handlers}")
            new_events = replace(self._events, **{key: handler})
        else:
            # Object style: .on(PipelineEvents(...))
            # A handler set on the incoming object wins; unset (falsy) slots
            # keep whatever was already registered.
            new_events = replace(
                self._events,
                on_step_start=event_or_handlers.on_step_start or self._events.on_step_start,
                on_step_complete=event_or_handlers.on_step_complete or self._events.on_step_complete,
                on_step_error=event_or_handlers.on_step_error or self._events.on_step_error,
                on_item_retry=event_or_handlers.on_item_retry or self._events.on_item_retry,
                on_worker_complete=event_or_handlers.on_worker_complete or self._events.on_worker_complete,
                on_verifier_complete=event_or_handlers.on_verifier_complete or self._events.on_verifier_complete,
                on_candidate_complete=event_or_handlers.on_candidate_complete or self._events.on_candidate_complete,
                on_judge_complete=event_or_handlers.on_judge_complete or self._events.on_judge_complete,
            )
        return Pipeline(
            _swarm=self._swarm,
            _steps=self._steps,
            _events=new_events,
        )

    # ===========================================================================
    # EXECUTION
    # ===========================================================================

    async def run(self, items: List[ItemInput]) -> PipelineResult[T]:
        """Execute the pipeline with the given items.

        Steps run sequentially; each step's surviving outputs become the next
        step's inputs. A reduce step short-circuits and returns immediately.

        Raises:
            Exception: whatever a step raises is re-raised after the
                on_step_error handler (if any) fires; errors are not swallowed.
        """
        run_id = secrets.token_hex(8)  # 16-hex-char id, unique per invocation
        step_results: List[StepResult] = []
        current_items: List[ItemInput] = list(items)  # defensive copy of caller's list
        start_time = time.time()

        for i, step in enumerate(self._steps):
            # Configs may not define a 'name'; treat it as optional metadata.
            step_name = getattr(step.config, 'name', None)
            step_start = time.time()

            if self._events.on_step_start:
                self._events.on_step_start(StepStartEvent(
                    type=step.type,
                    index=i,
                    name=step_name,
                    item_count=len(current_items),
                ))

            try:
                result = await self._execute_step(step, current_items, i, step_name)
                duration_ms = int((time.time() - step_start) * 1000)

                step_results.append(StepResult(
                    type=step.type,
                    index=i,
                    duration_ms=duration_ms,
                    results=result["output"],
                ))

                if self._events.on_step_complete:
                    self._events.on_step_complete(StepCompleteEvent(
                        type=step.type,
                        index=i,
                        name=step_name,
                        duration_ms=duration_ms,
                        success_count=result["success_count"],
                        error_count=result["error_count"],
                        filtered_count=result["filtered_count"],
                    ))

                # Reduce is terminal
                if step.type == "reduce":
                    return PipelineResult(
                        run_id=run_id,
                        steps=step_results,
                        output=result["output"],
                        total_duration_ms=int((time.time() - start_time) * 1000),
                    )

                current_items = result["next_items"]

            except Exception as e:
                # Notify observers, then propagate unchanged.
                if self._events.on_step_error:
                    self._events.on_step_error(StepErrorEvent(
                        type=step.type,
                        index=i,
                        name=step_name,
                        error=e,
                    ))
                raise

        # No reduce step: final output is the last step's results
        # (or [] for a pipeline with no steps).
        last_result = step_results[-1] if step_results else None
        return PipelineResult(
            run_id=run_id,
            steps=step_results,
            output=last_result.results if last_result else [],
            total_duration_ms=int((time.time() - start_time) * 1000),
        )

    # ===========================================================================
    # PRIVATE
    # ===========================================================================

    async def _execute_step(
        self,
        step: Step,
        items: List[ItemInput],
        step_index: int,
        step_name: Optional[str],
    ) -> dict:
        """Execute a single step and return results.

        Returns a dict with keys:
            output: step result object(s) recorded in StepResult.results
            next_items: items forwarded to the following step
            success_count / error_count / filtered_count: tallies for events
        """
        if step.type == "map":
            config = step.config
            results = await self._swarm.map(
                items=items,
                prompt=config.prompt,
                system_prompt=config.system_prompt,
                schema=config.schema,
                schema_options=config.schema_options,
                agent=config.agent,
                mcp_servers=config.mcp_servers,
                best_of=self._wrap_best_of(config.best_of, step_index, step_name),
                verify=self._wrap_verify(config.verify, step_index, step_name),
                retry=self._wrap_retry(config.retry, step_index, step_name),
                timeout_ms=config.timeout_ms,
            )
            return {
                "output": list(results),
                "next_items": results.success,  # only successful results flow on
                "success_count": len(results.success),
                "error_count": len(results.error),
                "filtered_count": 0,  # map never filters
            }

        if step.type == "filter":
            config = step.config
            results = await self._swarm.filter(
                items=items,
                prompt=config.prompt,
                schema=config.schema,
                condition=config.condition,
                schema_options=config.schema_options,
                system_prompt=config.system_prompt,
                agent=config.agent,
                mcp_servers=config.mcp_servers,
                verify=self._wrap_verify(config.verify, step_index, step_name),
                retry=self._wrap_retry(config.retry, step_index, step_name),
                timeout_ms=config.timeout_ms,
            )
            # 'emit' selects which partition continues downstream; defaults to
            # the items that passed the filter condition.
            emit = getattr(config, 'emit', 'success')
            if emit == "success":
                next_items = results.success
            elif emit == "filtered":
                next_items = results.filtered
            else:  # "all"
                next_items = results.success + results.filtered
            return {
                "output": list(results),
                "next_items": next_items,
                "success_count": len(results.success),
                "error_count": len(results.error),
                "filtered_count": len(results.filtered),
            }

        # reduce (the only remaining step type)
        config = step.config
        result = await self._swarm.reduce(
            items=items,
            prompt=config.prompt,
            system_prompt=config.system_prompt,
            schema=config.schema,
            schema_options=config.schema_options,
            agent=config.agent,
            mcp_servers=config.mcp_servers,
            verify=self._wrap_verify(config.verify, step_index, step_name),
            retry=self._wrap_retry(config.retry, step_index, step_name),
            timeout_ms=config.timeout_ms,
        )
        return {
            "output": result,  # single ReduceResult, not a list
            "next_items": [],  # reduce is terminal; nothing flows on
            "success_count": 1 if result.status == "success" else 0,
            "error_count": 1 if result.status == "error" else 0,
            "filtered_count": 0,
        }

    def _wrap_retry(
        self,
        config: Optional[RetryConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[RetryConfig]:
        """Wrap retry config to inject pipeline-level callback.

        The user's own on_item_retry (if any) still fires first; the
        pipeline-level on_item_retry event fires after it.
        """
        if config is None:
            return None

        original_callback = config.on_item_retry

        def wrapped_callback(item_index: int, attempt: int, error: str):
            if original_callback:
                original_callback(item_index, attempt, error)
            if self._events.on_item_retry:
                self._events.on_item_retry(ItemRetryEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    error=error,
                ))

        # NOTE(review): fields are copied one-by-one, so a field added to
        # RetryConfig later would be silently dropped here — consider
        # dataclasses.replace(config, on_item_retry=wrapped_callback) if
        # RetryConfig is a dataclass; confirm before changing.
        return RetryConfig(
            max_attempts=config.max_attempts,
            backoff_ms=config.backoff_ms,
            backoff_multiplier=config.backoff_multiplier,
            retry_on=config.retry_on,
            on_item_retry=wrapped_callback,
        )

    def _wrap_verify(
        self,
        config: Optional[VerifyConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[VerifyConfig]:
        """Wrap verify config to inject pipeline-level callbacks.

        User callbacks (if any) fire before the corresponding pipeline events.
        """
        if config is None:
            return None

        original_worker = config.on_worker_complete
        original_verifier = config.on_verifier_complete

        def wrapped_worker(item_index: int, attempt: int, status: str):
            if original_worker:
                original_worker(item_index, attempt, status)
            if self._events.on_worker_complete:
                self._events.on_worker_complete(WorkerCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    status=status,
                ))

        def wrapped_verifier(item_index: int, attempt: int, passed: bool, feedback: Optional[str]):
            if original_verifier:
                original_verifier(item_index, attempt, passed, feedback)
            if self._events.on_verifier_complete:
                self._events.on_verifier_complete(VerifierCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    attempt=attempt,
                    passed=passed,
                    feedback=feedback,
                ))

        # NOTE(review): same field-by-field copy caveat as _wrap_retry.
        return VerifyConfig(
            criteria=config.criteria,
            max_attempts=config.max_attempts,
            verifier_agent=config.verifier_agent,
            verifier_mcp_servers=config.verifier_mcp_servers,
            on_worker_complete=wrapped_worker,
            on_verifier_complete=wrapped_verifier,
        )

    def _wrap_best_of(
        self,
        config: Optional[BestOfConfig],
        step_index: int,
        step_name: Optional[str],
    ) -> Optional[BestOfConfig]:
        """Wrap bestOf config to inject pipeline-level callbacks.

        User callbacks (if any) fire before the corresponding pipeline events.
        """
        if config is None:
            return None

        original_candidate = config.on_candidate_complete
        original_judge = config.on_judge_complete

        def wrapped_candidate(item_index: int, candidate_index: int, status: str):
            if original_candidate:
                original_candidate(item_index, candidate_index, status)
            if self._events.on_candidate_complete:
                self._events.on_candidate_complete(CandidateCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    candidate_index=candidate_index,
                    status=status,
                ))

        def wrapped_judge(item_index: int, winner_index: int, reasoning: str):
            if original_judge:
                original_judge(item_index, winner_index, reasoning)
            if self._events.on_judge_complete:
                self._events.on_judge_complete(JudgeCompleteEvent(
                    step_index=step_index,
                    step_name=step_name,
                    item_index=item_index,
                    winner_index=winner_index,
                    reasoning=reasoning,
                ))

        # NOTE(review): same field-by-field copy caveat as _wrap_retry.
        return BestOfConfig(
            judge_criteria=config.judge_criteria,
            n=config.n,
            task_agents=config.task_agents,
            judge_agent=config.judge_agent,
            mcp_servers=config.mcp_servers,
            judge_mcp_servers=config.judge_mcp_servers,
            on_candidate_complete=wrapped_candidate,
            on_judge_complete=wrapped_judge,
        )
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
# =============================================================================
|
|
439
|
+
# TERMINAL PIPELINE
|
|
440
|
+
# =============================================================================
|
|
441
|
+
|
|
442
|
+
@dataclass
class TerminalPipeline(Pipeline[T]):
    """Pipeline after reduce - no more steps can be added."""

    def map(self, config: MapConfig) -> 'Pipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    def filter(self, config: FilterConfig) -> 'Pipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    def reduce(self, config: ReduceConfig) -> 'TerminalPipeline':
        """Cannot add steps after reduce."""
        raise RuntimeError("Cannot add steps after reduce")

    @overload
    def on(self, handlers: PipelineEvents) -> 'TerminalPipeline[T]': ...

    @overload
    def on(self, event: str, handler: Callable) -> 'TerminalPipeline[T]': ...

    def on(self, event_or_handlers: Union[PipelineEvents, str], handler: Optional[Callable] = None) -> 'TerminalPipeline[T]':
        """Register event handlers for step lifecycle.

        Accepts either a single event name plus handler, or a whole
        PipelineEvents object; returns a new TerminalPipeline.
        """
        if isinstance(event_or_handlers, str):
            # Single-event form: resolve the handler slot by event name.
            slot = PipelineEventMap.get(event_or_handlers)
            if slot is None:
                raise ValueError(f"Unknown event: {event_or_handlers}")
            merged = replace(self._events, **{slot: handler})
        else:
            # Bulk form: a handler set on the incoming object wins; unset
            # (falsy) slots keep whatever was already registered.
            incoming, current = event_or_handlers, self._events
            merged = replace(current, **{
                slot: getattr(incoming, slot) or getattr(current, slot)
                for slot in (
                    "on_step_start",
                    "on_step_complete",
                    "on_step_error",
                    "on_item_retry",
                    "on_worker_complete",
                    "on_verifier_complete",
                    "on_candidate_complete",
                    "on_judge_complete",
                )
            })
        return TerminalPipeline(
            _swarm=self._swarm,
            _steps=self._steps,
            _events=merged,
        )