ai-pipeline-core 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ai_pipeline_core/__init__.py +14 -4
  2. ai_pipeline_core/deployment/__init__.py +46 -0
  3. ai_pipeline_core/deployment/base.py +681 -0
  4. ai_pipeline_core/deployment/contract.py +84 -0
  5. ai_pipeline_core/deployment/helpers.py +98 -0
  6. ai_pipeline_core/documents/flow_document.py +1 -1
  7. ai_pipeline_core/documents/task_document.py +1 -1
  8. ai_pipeline_core/documents/temporary_document.py +1 -1
  9. ai_pipeline_core/flow/config.py +13 -2
  10. ai_pipeline_core/flow/options.py +1 -1
  11. ai_pipeline_core/llm/client.py +22 -23
  12. ai_pipeline_core/llm/model_response.py +6 -3
  13. ai_pipeline_core/llm/model_types.py +0 -1
  14. ai_pipeline_core/pipeline.py +1 -1
  15. ai_pipeline_core/progress.py +127 -0
  16. ai_pipeline_core/prompt_builder/__init__.py +5 -0
  17. ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +23 -0
  18. ai_pipeline_core/prompt_builder/global_cache.py +78 -0
  19. ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +6 -0
  20. ai_pipeline_core/prompt_builder/prompt_builder.py +253 -0
  21. ai_pipeline_core/prompt_builder/system_prompt.jinja2 +41 -0
  22. ai_pipeline_core/tracing.py +1 -1
  23. ai_pipeline_core/utils/remote_deployment.py +37 -187
  24. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/METADATA +23 -20
  25. ai_pipeline_core-0.3.0.dist-info/RECORD +49 -0
  26. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/WHEEL +1 -1
  27. ai_pipeline_core/simple_runner/__init__.py +0 -14
  28. ai_pipeline_core/simple_runner/cli.py +0 -254
  29. ai_pipeline_core/simple_runner/simple_runner.py +0 -247
  30. ai_pipeline_core-0.2.8.dist-info/RECORD +0 -41
  31. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/licenses/LICENSE +0 -0
ai_pipeline_core/deployment/base.py (new file)
@@ -0,0 +1,681 @@
+ """Core classes for pipeline deployments.
+ 
+ @public
+ 
+ Provides the PipelineDeployment base class and related types for
+ creating unified, type-safe pipeline deployments with:
+ - Per-flow caching (skip if outputs exist)
+ - Per-flow uploads (immediate, not just at end)
+ - Prefect state hooks (on_running, on_completion, etc.)
+ - Smart storage provisioning (override provision_storage)
+ - Upload on failure (partial results saved)
+ """
+ 
+ import asyncio
+ import os
+ import re
+ import sys
+ from abc import abstractmethod
+ from contextlib import ExitStack
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta, timezone
+ from hashlib import sha256
+ from pathlib import Path
+ from typing import Any, Callable, ClassVar, Generic, Protocol, TypeVar, cast, final
+ from uuid import UUID
+ 
+ import httpx
+ from lmnr import Laminar
+ from prefect import get_client
+ from pydantic import BaseModel, ConfigDict, Field
+ from pydantic_settings import CliPositionalArg, SettingsConfigDict
+ 
+ from ai_pipeline_core.documents import DocumentList
+ from ai_pipeline_core.flow.options import FlowOptions
+ from ai_pipeline_core.logging import get_pipeline_logger, setup_logging
+ from ai_pipeline_core.prefect import disable_run_logger, flow, prefect_test_harness
+ from ai_pipeline_core.settings import settings
+ 
+ from .contract import CompletedRun, DeploymentResultData, FailedRun, ProgressRun
+ from .helpers import (
+     StatusPayload,
+     class_name_to_deployment_name,
+     download_documents,
+     extract_generic_params,
+     send_webhook,
+     upload_documents,
+ )
+ 
+ logger = get_pipeline_logger(__name__)
+ 
+ 
+ class DeploymentContext(BaseModel):
+     """@public Infrastructure configuration for deployments.
+ 
+     Webhooks are optional - provide URLs to enable:
+     - progress_webhook_url: Per-flow progress (started/completed/cached)
+     - status_webhook_url: Prefect state transitions (RUNNING/FAILED/etc)
+     - completion_webhook_url: Final result when deployment ends
+     """
+ 
+     input_documents_urls: list[str] = Field(default_factory=list)
+     output_documents_urls: dict[str, str] = Field(default_factory=dict)
+ 
+     progress_webhook_url: str = ""
+     status_webhook_url: str = ""
+     completion_webhook_url: str = ""
+ 
+     model_config = ConfigDict(frozen=True, extra="forbid")
+ 
+ 
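For reference, constructing a context looks like the sketch below. Every URL is a hypothetical placeholder, and any webhook left as an empty string is simply skipped.

    # Sketch only: all URLs here are hypothetical placeholders.
    context = DeploymentContext(
        input_documents_urls=["https://example.com/inputs/brief.pdf"],
        output_documents_urls={"report": "https://example.com/upload/report"},
        progress_webhook_url="https://example.com/hooks/progress",
        # status_webhook_url / completion_webhook_url left empty: disabled.
    )
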
+ class DeploymentResult(BaseModel):
+     """@public Base class for deployment results."""
+ 
+     success: bool
+     error: str | None = None
+ 
+     model_config = ConfigDict(frozen=True, extra="forbid")
+ 
+ 
+ TOptions = TypeVar("TOptions", bound=FlowOptions)
+ TResult = TypeVar("TResult", bound=DeploymentResult)
+ 
+ 
+ class FlowCallable(Protocol):
+     """Protocol for @pipeline_flow decorated functions."""
+ 
+     config: Any
+     name: str
+     __name__: str
+ 
+     def __call__(
+         self, project_name: str, documents: DocumentList, flow_options: FlowOptions
+     ) -> Any: ...
+ 
+     def with_options(self, **kwargs: Any) -> "FlowCallable":
+         """Return a copy with overridden Prefect flow options."""
+         ...
+ 
+ 
+ @dataclass(slots=True)
+ class _StatusWebhookHook:
+     """Prefect hook that sends status webhooks on state transitions."""
+ 
+     webhook_url: str
+     flow_run_id: str
+     project_name: str
+     step: int
+     total_steps: int
+     flow_name: str
+ 
+     async def __call__(self, flow: Any, flow_run: Any, state: Any) -> None:
+         payload: StatusPayload = {
+             "type": "status",
+             "flow_run_id": str(flow_run.id),
+             "project_name": self.project_name,
+             "step": self.step,
+             "total_steps": self.total_steps,
+             "flow_name": self.flow_name,
+             "state": state.type.value if hasattr(state.type, "value") else str(state.type),
+             "state_name": state.name or "",
+             "timestamp": datetime.now(timezone.utc).isoformat(),
+         }
+         try:
+             async with httpx.AsyncClient(timeout=10) as client:
+                 await client.post(self.webhook_url, json=payload)
+         except Exception as e:
+             logger.warning(f"Status webhook failed: {e}")
+ 
+ 
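The hook above posts JSON shaped like the following (all values illustrative, mirroring the fields built in __call__):

    # Illustrative status payload as received by status_webhook_url.
    {
        "type": "status",
        "flow_run_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
        "project_name": "demo",
        "step": 2,
        "total_steps": 4,
        "flow_name": "enrich_documents",
        "state": "RUNNING",
        "state_name": "Running",
        "timestamp": "2025-01-01T12:00:00+00:00",
    }
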
+ class PipelineDeployment(Generic[TOptions, TResult]):
+     """@public Base class for pipeline deployments.
+ 
+     Features enabled by default when URLs/storage provided:
+     - Per-flow caching: Skip flows if outputs exist in storage
+     - Per-flow uploads: Upload documents after each flow
+     - Prefect hooks: Attach state hooks if status_webhook_url provided
+     - Upload on failure: Save partial results if pipeline fails
+     """
+ 
+     flows: ClassVar[list[FlowCallable]]
+     name: ClassVar[str]
+     options_type: ClassVar[type[FlowOptions]]
+     result_type: ClassVar[type[DeploymentResult]]
+ 
+     def __init_subclass__(cls, **kwargs: Any) -> None:
+         super().__init_subclass__(**kwargs)
+ 
+         if not hasattr(cls, "flows"):
+             return
+ 
+         if cls.__name__.startswith("Test"):
+             raise TypeError(f"Deployment class name cannot start with 'Test': {cls.__name__}")
+ 
+         cls.name = class_name_to_deployment_name(cls.__name__)
+ 
+         options_type, result_type = extract_generic_params(cls)
+         if options_type is None or result_type is None:
+             raise TypeError(
+                 f"{cls.__name__} must specify Generic parameters: "
+                 f"class {cls.__name__}(PipelineDeployment[MyOptions, MyResult])"
+             )
+ 
+         cls.options_type = options_type
+         cls.result_type = result_type
+ 
+         if not cls.flows:
+             raise TypeError(f"{cls.__name__}.flows cannot be empty")
+ 
+     @staticmethod
+     @abstractmethod
+     def build_result(project_name: str, documents: DocumentList, options: TOptions) -> TResult:
+         """Extract typed result from accumulated pipeline documents."""
+         ...
+ 
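A concrete build_result might look like the sketch below; ReportDocument, MyOptions, and MyResult are hypothetical example types, and the decode assumes text content:

    # Sketch: ReportDocument, MyOptions, MyResult are hypothetical types.
    @staticmethod
    def build_result(project_name: str, documents: DocumentList, options: "MyOptions") -> "MyResult":
        reports = [d for d in documents if isinstance(d, ReportDocument)]
        if not reports:
            return MyResult(success=False, error="no report produced")
        return MyResult(success=True, summary=reports[-1].content.decode())
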
+     async def provision_storage(
+         self,
+         project_name: str,
+         documents: DocumentList,
+         options: TOptions,
+         context: DeploymentContext,
+     ) -> str:
+         """Provision GCS storage bucket based on project name and content hash.
+ 
+         Default: Creates `{project}-{date}-{hash}` bucket on GCS.
+         Returns empty string if GCS is unavailable or creation fails.
+         Override for custom storage provisioning logic.
+         """
+         if not documents:
+             return ""
+ 
+         try:
+             from ai_pipeline_core.storage.storage import GcsStorage  # noqa: PLC0415
+         except ImportError:
+             return ""
+ 
+         content_hash = sha256(b"".join(sorted(d.content for d in documents))).hexdigest()[:6]
+         base = re.sub(r"[^a-z0-9-]", "-", project_name.lower()).strip("-") or "project"
+         today = datetime.now(timezone.utc).strftime("%y-%m-%d")
+         yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).strftime("%y-%m-%d")
+ 
+         today_bucket = f"{base[:30]}-{today}-{content_hash}"
+         yesterday_bucket = f"{base[:30]}-{yesterday}-{content_hash}"
+ 
+         # Try today's bucket, then yesterday's, then create new
+         for bucket_name in (today_bucket, yesterday_bucket):
+             try:
+                 storage = GcsStorage(bucket_name)
+                 if await storage.list(recursive=False):
+                     logger.info(f"Using existing bucket: {bucket_name}")
+                     return f"gs://{bucket_name}"
+             except Exception:
+                 continue
+ 
+         try:
+             storage = GcsStorage(today_bucket)
+             await storage.create_bucket()
+             logger.info(f"Created new bucket: {today_bucket}")
+             return f"gs://{today_bucket}"
+         except Exception as e:
+             logger.warning(f"Failed to provision GCS storage: {e}")
+             return ""
+ 
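A worked example of the naming scheme above (date and hash values illustrative):

    base = re.sub(r"[^a-z0-9-]", "-", "My Project!".lower()).strip("-")  # "my-project"
    bucket = f"{base[:30]}-25-01-02-a1b2c3"  # "my-project-25-01-02-a1b2c3"
    uri = f"gs://{bucket}"
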
+     async def _load_cached_output(
+         self, flow_fn: FlowCallable, storage_uri: str
+     ) -> DocumentList | None:
+         """Load cached outputs if they exist. Override for custom cache logic."""
+         try:
+             output_type = flow_fn.config.OUTPUT_DOCUMENT_TYPE
+             docs = await flow_fn.config.load_documents_by_type(storage_uri, [output_type])
+             return docs if docs else None
+         except Exception:
+             return None
+ 
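Since this is the documented override point, a subclass can opt out of per-flow caching entirely with a sketch like the following (MyDeployment is a hypothetical subclass):

    # Sketch: disable caching by never reporting a cache hit.
    class UncachedDeployment(MyDeployment):
        async def _load_cached_output(
            self, flow_fn: FlowCallable, storage_uri: str
        ) -> DocumentList | None:
            return None
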
+     def _build_status_hooks(
+         self,
+         context: DeploymentContext,
+         flow_run_id: str,
+         project_name: str,
+         step: int,
+         total_steps: int,
+         flow_name: str,
+     ) -> dict[str, list[Callable[..., Any]]]:
+         """Build Prefect hooks for status webhooks."""
+         hook = _StatusWebhookHook(
+             webhook_url=context.status_webhook_url,
+             flow_run_id=flow_run_id,
+             project_name=project_name,
+             step=step,
+             total_steps=total_steps,
+             flow_name=flow_name,
+         )
+         return {
+             "on_running": [hook],
+             "on_completion": [hook],
+             "on_failure": [hook],
+             "on_crashed": [hook],
+             "on_cancellation": [hook],
+         }
+ 
+     async def _send_progress(
+         self,
+         context: DeploymentContext,
+         flow_run_id: str,
+         project_name: str,
+         storage_uri: str,
+         step: int,
+         total_steps: int,
+         flow_name: str,
+         status: str,
+         step_progress: float = 0.0,
+         message: str = "",
+     ) -> None:
+         """Send progress webhook and update flow run labels."""
+         progress = round((step - 1 + step_progress) / total_steps, 4)
+ 
+         if context.progress_webhook_url:
+             payload = ProgressRun(
+                 flow_run_id=UUID(flow_run_id) if flow_run_id else UUID(int=0),
+                 project_name=project_name,
+                 state="RUNNING",
+                 timestamp=datetime.now(timezone.utc),
+                 storage_uri=storage_uri,
+                 step=step,
+                 total_steps=total_steps,
+                 flow_name=flow_name,
+                 status=status,
+                 progress=progress,
+                 step_progress=round(step_progress, 4),
+                 message=message,
+             )
+             try:
+                 await send_webhook(context.progress_webhook_url, payload)
+             except Exception as e:
+                 logger.warning(f"Progress webhook failed: {e}")
+ 
+         if flow_run_id:
+             try:
+                 async with get_client() as client:
+                     await client.update_flow_run_labels(
+                         flow_run_id=UUID(flow_run_id),
+                         labels={
+                             "progress.step": step,
+                             "progress.total_steps": total_steps,
+                             "progress.flow_name": flow_name,
+                             "progress.status": status,
+                             "progress.progress": progress,
+                             "progress.step_progress": round(step_progress, 4),
+                             "progress.message": message,
+                         },
+                     )
+             except Exception as e:
+                 logger.warning(f"Progress label update failed: {e}")
+ 
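A worked example of the progress formula above:

    # step=3, total_steps=4, step_progress=0.5
    # progress = round((3 - 1 + 0.5) / 4, 4) = 0.625
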
+     async def _send_completion(
+         self,
+         context: DeploymentContext,
+         flow_run_id: str,
+         project_name: str,
+         storage_uri: str,
+         result: TResult | None,
+         error: str | None,
+     ) -> None:
+         """Send completion webhook."""
+         if not context.completion_webhook_url:
+             return
+         try:
+             now = datetime.now(timezone.utc)
+             frid = UUID(flow_run_id) if flow_run_id else UUID(int=0)
+             payload: CompletedRun | FailedRun
+             if result is not None:
+                 payload = CompletedRun(
+                     flow_run_id=frid,
+                     project_name=project_name,
+                     timestamp=now,
+                     storage_uri=storage_uri,
+                     state="COMPLETED",
+                     result=DeploymentResultData.model_validate(result.model_dump()),
+                 )
+             else:
+                 payload = FailedRun(
+                     flow_run_id=frid,
+                     project_name=project_name,
+                     timestamp=now,
+                     storage_uri=storage_uri,
+                     state="FAILED",
+                     error=error or "Unknown error",
+                 )
+             await send_webhook(context.completion_webhook_url, payload)
+         except Exception as e:
+             logger.warning(f"Completion webhook failed: {e}")
+ 
+     @final
+     async def run(
+         self,
+         project_name: str,
+         documents: str | DocumentList,
+         options: TOptions,
+         context: DeploymentContext,
+     ) -> TResult:
+         """Execute flows with caching, uploads, and webhooks enabled by default."""
+         from prefect import runtime  # noqa: PLC0415
+ 
+         total_steps = len(self.flows)
+         flow_run_id = str(runtime.flow_run.get_id()) if runtime.flow_run else ""  # pyright: ignore[reportAttributeAccessIssue]
+ 
+         # Resolve storage URI and documents
+         if isinstance(documents, str):
+             storage_uri = documents
+             docs = await self.flows[0].config.load_documents(storage_uri)
+         else:
+             docs = documents
+             storage_uri = await self.provision_storage(project_name, docs, options, context)
+             if storage_uri and docs:
+                 await self.flows[0].config.save_documents(
+                     storage_uri, docs, validate_output_type=False
+                 )
+ 
+         # Write identity labels for polling endpoint
+         if flow_run_id:
+             try:
+                 async with get_client() as client:
+                     await client.update_flow_run_labels(
+                         flow_run_id=UUID(flow_run_id),
+                         labels={
+                             "pipeline.project_name": project_name,
+                             "pipeline.storage_uri": storage_uri,
+                         },
+                     )
+             except Exception as e:
+                 logger.warning(f"Identity label update failed: {e}")
+ 
+         # Download additional input documents
+         if context.input_documents_urls:
+             first_input_type = self.flows[0].config.INPUT_DOCUMENT_TYPES[0]
+             downloaded = await download_documents(context.input_documents_urls, first_input_type)
+             docs = DocumentList(list(docs) + list(downloaded))
+ 
+         accumulated_docs = docs
+         completion_sent = False
+ 
+         try:
+             for step, flow_fn in enumerate(self.flows, start=1):
+                 flow_name = getattr(flow_fn, "name", flow_fn.__name__)
+                 flow_run_id = str(runtime.flow_run.get_id()) if runtime.flow_run else ""  # pyright: ignore[reportAttributeAccessIssue]
+ 
+                 # Per-flow caching: check if outputs exist
+                 if storage_uri:
+                     cached = await self._load_cached_output(flow_fn, storage_uri)
+                     if cached is not None:
+                         logger.info(f"[{step}/{total_steps}] Cache hit: {flow_name}")
+                         accumulated_docs = DocumentList(list(accumulated_docs) + list(cached))
+                         await self._send_progress(
+                             context,
+                             flow_run_id,
+                             project_name,
+                             storage_uri,
+                             step,
+                             total_steps,
+                             flow_name,
+                             "cached",
+                             step_progress=1.0,
+                             message=f"Loaded from cache: {flow_name}",
+                         )
+                         continue
+ 
+                 # Prefect state hooks
+                 active_flow = flow_fn
+                 if context.status_webhook_url:
+                     hooks = self._build_status_hooks(
+                         context, flow_run_id, project_name, step, total_steps, flow_name
+                     )
+                     active_flow = flow_fn.with_options(**hooks)
+ 
+                 # Progress: started
+                 await self._send_progress(
+                     context,
+                     flow_run_id,
+                     project_name,
+                     storage_uri,
+                     step,
+                     total_steps,
+                     flow_name,
+                     "started",
+                     step_progress=0.0,
+                     message=f"Starting: {flow_name}",
+                 )
+ 
+                 logger.info(f"[{step}/{total_steps}] Starting: {flow_name}")
+ 
+                 # Load documents for this flow
+                 if storage_uri:
+                     current_docs = await flow_fn.config.load_documents(storage_uri)
+                 else:
+                     current_docs = accumulated_docs
+ 
+                 try:
+                     new_docs = await active_flow(project_name, current_docs, options)
+                 except Exception as e:
+                     # Upload partial results on failure
+                     if context.output_documents_urls:
+                         await upload_documents(accumulated_docs, context.output_documents_urls)
+                     await self._send_completion(
+                         context, flow_run_id, project_name, storage_uri, result=None, error=str(e)
+                     )
+                     completion_sent = True
+                     raise
+ 
+                 # Save to storage
+                 if storage_uri:
+                     await flow_fn.config.save_documents(storage_uri, new_docs)
+ 
+                 accumulated_docs = DocumentList(list(accumulated_docs) + list(new_docs))
+ 
+                 # Per-flow upload
+                 if context.output_documents_urls:
+                     await upload_documents(new_docs, context.output_documents_urls)
+ 
+                 # Progress: completed
+                 await self._send_progress(
+                     context,
+                     flow_run_id,
+                     project_name,
+                     storage_uri,
+                     step,
+                     total_steps,
+                     flow_name,
+                     "completed",
+                     step_progress=1.0,
+                     message=f"Completed: {flow_name}",
+                 )
+ 
+                 logger.info(f"[{step}/{total_steps}] Completed: {flow_name}")
+ 
+             result = self.build_result(project_name, accumulated_docs, options)
+             await self._send_completion(
+                 context, flow_run_id, project_name, storage_uri, result=result, error=None
+             )
+             return result
+ 
+         except Exception as e:
+             if not completion_sent:
+                 await self._send_completion(
+                     context, flow_run_id, project_name, storage_uri, result=None, error=str(e)
+                 )
+             raise
+ 
+     @final
+     def run_local(
+         self,
+         project_name: str,
+         documents: str | DocumentList,
+         options: TOptions,
+         context: DeploymentContext | None = None,
+         output_dir: Path | None = None,
+     ) -> TResult:
+         """Run locally with Prefect test harness."""
+         if context is None:
+             context = DeploymentContext()
+ 
+         # If output_dir provided and documents is DocumentList, use output_dir as storage
+         if output_dir and isinstance(documents, DocumentList):
+             output_dir.mkdir(parents=True, exist_ok=True)
+             documents = str(output_dir)
+ 
+         with prefect_test_harness():
+             with disable_run_logger():
+                 result = asyncio.run(self.run(project_name, documents, options, context))
+ 
+         if output_dir:
+             output_dir.mkdir(parents=True, exist_ok=True)
+             (output_dir / "result.json").write_text(result.model_dump_json(indent=2))
+ 
+         return result
+ 
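Typical local usage is sketched below; MyDeployment and MyOptions are hypothetical subclasses of PipelineDeployment and FlowOptions:

    # Sketch: run the whole pipeline against a local directory.
    deployment = MyDeployment()
    result = deployment.run_local(
        project_name="demo",
        documents=str(Path("./work")),  # or a DocumentList plus output_dir
        options=MyOptions(),
        output_dir=Path("./work"),      # result.json is written here
    )
    print(result.success)
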
+     @final
+     def run_cli(
+         self,
+         initializer: Callable[[TOptions], tuple[str, DocumentList]] | None = None,
+         trace_name: str | None = None,
+     ) -> None:
+         """Execute pipeline from CLI arguments with --start/--end step control."""
+         if len(sys.argv) == 1:
+             sys.argv.append("--help")
+ 
+         setup_logging()
+         try:
+             Laminar.initialize()
+             logger.info("LMNR tracing initialized.")
+         except Exception as e:
+             logger.warning(f"Failed to initialize LMNR: {e}")
+ 
+         deployment = self
+ 
+         class _CliOptions(
+             deployment.options_type,
+             cli_parse_args=True,
+             cli_kebab_case=True,
+             cli_exit_on_error=True,
+             cli_prog_name=deployment.name,
+             cli_use_class_docs_for_groups=True,
+         ):
+             working_directory: CliPositionalArg[Path]
+             project_name: str | None = None
+             start: int = 1
+             end: int | None = None
+ 
+             model_config = SettingsConfigDict(frozen=True, extra="ignore")
+ 
+         opts = cast(TOptions, _CliOptions())  # type: ignore[reportCallIssue]
+ 
+         wd: Path = getattr(opts, "working_directory")
+         wd.mkdir(parents=True, exist_ok=True)
+ 
+         project_name = getattr(opts, "project_name") or wd.name
+         start_step = getattr(opts, "start", 1)
+         end_step = getattr(opts, "end", None)
+ 
+         # Initialize documents and save to working directory
+         if initializer and start_step == 1:
+             _, documents = initializer(opts)
+             if documents and self.flows:
+                 first_config = getattr(self.flows[0], "config", None)
+                 if first_config:
+                     asyncio.run(
+                         first_config.save_documents(str(wd), documents, validate_output_type=False)
+                     )
+ 
+         context = DeploymentContext()
+ 
+         with ExitStack() as stack:
+             if trace_name:
+                 stack.enter_context(
+                     Laminar.start_as_current_span(
+                         name=f"{trace_name}-{project_name}",
+                         input=[opts.model_dump_json()],
+                     )
+                 )
+ 
+             under_pytest = "PYTEST_CURRENT_TEST" in os.environ or "pytest" in sys.modules
+             if not settings.prefect_api_key and not under_pytest:
+                 stack.enter_context(prefect_test_harness())
+                 stack.enter_context(disable_run_logger())
+ 
+             result = asyncio.run(
+                 self._run_with_steps(
+                     project_name=project_name,
+                     storage_uri=str(wd),
+                     options=opts,
+                     context=context,
+                     start_step=start_step,
+                     end_step=end_step,
+                 )
+             )
+ 
+         result_file = wd / "result.json"
+         result_file.write_text(result.model_dump_json(indent=2))
+         logger.info(f"Result saved to {result_file}")
+ 
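Wiring this into a script is sketched below (MyDeployment hypothetical). The flags come from the options model plus the built-ins added above, exposed in kebab-case:

    # Sketch: CLI entrypoint for a deployment subclass.
    if __name__ == "__main__":
        MyDeployment().run_cli()

    # Example invocations:
    #   python pipeline.py ./work_dir --project-name demo
    #   python pipeline.py ./work_dir --start 2 --end 3  # resume at step 2, stop after 3
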
+     async def _run_with_steps(
+         self,
+         project_name: str,
+         storage_uri: str,
+         options: TOptions,
+         context: DeploymentContext,
+         start_step: int = 1,
+         end_step: int | None = None,
+     ) -> TResult:
+         """Run pipeline with start/end step control for CLI resume support."""
+         if end_step is None:
+             end_step = len(self.flows)
+ 
+         total_steps = len(self.flows)
+         accumulated_docs = DocumentList([])
+ 
+         for i in range(start_step - 1, end_step):
+             step = i + 1
+             flow_fn = self.flows[i]
+             flow_name = getattr(flow_fn, "name", flow_fn.__name__)
+             logger.info(f"--- [Step {step}/{total_steps}] {flow_name} ---")
+ 
+             # Check cache
+             cached = await self._load_cached_output(flow_fn, storage_uri)
+             if cached is not None:
+                 logger.info(f"[{step}/{total_steps}] Cache hit: {flow_name}")
+                 accumulated_docs = DocumentList(list(accumulated_docs) + list(cached))
+                 continue
+ 
+             current_docs = await flow_fn.config.load_documents(storage_uri)
+             new_docs = await flow_fn(project_name, current_docs, options)
+             await flow_fn.config.save_documents(storage_uri, new_docs)
+             accumulated_docs = DocumentList(list(accumulated_docs) + list(new_docs))
+ 
+         return self.build_result(project_name, accumulated_docs, options)
+ 
+     @final
+     def as_prefect_flow(self) -> Callable[..., Any]:
+         """Generate Prefect flow for production deployment."""
+         deployment = self
+ 
+         @flow(  # pyright: ignore[reportUntypedFunctionDecorator]
+             name=self.name,
+             flow_run_name=f"{self.name}-{{project_name}}",
+             persist_result=True,
+             result_serializer="json",
+         )
+         async def _deployment_flow(
+             project_name: str,
+             documents: str | DocumentList,
+             options: FlowOptions,
+             context: DeploymentContext,
+         ) -> DeploymentResult:
+             return await deployment.run(project_name, documents, cast(Any, options), context)
+ 
+         return _deployment_flow
+ 
+ 
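Assuming the @flow wrapper returns a standard Prefect Flow object, serving it for production looks like this sketch (names hypothetical):

    # Sketch: expose the deployment as a long-running served flow.
    prefect_flow = MyDeployment().as_prefect_flow()
    prefect_flow.serve(name="my-pipeline")  # Prefect's Flow.serve; blocks and executes runs
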
+ __all__ = [
+     "DeploymentContext",
+     "DeploymentResult",
+     "PipelineDeployment",
+ ]
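
Putting the pieces together, a minimal end-to-end subclass is sketched below. Every name is hypothetical, and gather_flow / summarize_flow stand in for @pipeline_flow-decorated functions satisfying FlowCallable:

    # Minimal end-to-end sketch of the new deployment API.
    class MyOptions(FlowOptions):
        temperature: float = 0.7

    class MyResult(DeploymentResult):
        summary: str = ""

    class ResearchPipeline(PipelineDeployment[MyOptions, MyResult]):
        flows = [gather_flow, summarize_flow]

        @staticmethod
        def build_result(
            project_name: str, documents: DocumentList, options: MyOptions
        ) -> MyResult:
            return MyResult(success=True, summary=f"{len(documents)} documents for {project_name}")

    # ResearchPipeline.name is derived from the class name
    # (via class_name_to_deployment_name), e.g. "research-pipeline".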