ai-pipeline-core 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_pipeline_core/__init__.py +14 -4
- ai_pipeline_core/deployment/__init__.py +46 -0
- ai_pipeline_core/deployment/base.py +681 -0
- ai_pipeline_core/deployment/contract.py +84 -0
- ai_pipeline_core/deployment/helpers.py +98 -0
- ai_pipeline_core/documents/flow_document.py +1 -1
- ai_pipeline_core/documents/task_document.py +1 -1
- ai_pipeline_core/documents/temporary_document.py +1 -1
- ai_pipeline_core/flow/config.py +13 -2
- ai_pipeline_core/flow/options.py +1 -1
- ai_pipeline_core/llm/client.py +22 -23
- ai_pipeline_core/llm/model_response.py +6 -3
- ai_pipeline_core/llm/model_types.py +0 -1
- ai_pipeline_core/pipeline.py +1 -1
- ai_pipeline_core/progress.py +127 -0
- ai_pipeline_core/prompt_builder/__init__.py +5 -0
- ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +23 -0
- ai_pipeline_core/prompt_builder/global_cache.py +78 -0
- ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +6 -0
- ai_pipeline_core/prompt_builder/prompt_builder.py +253 -0
- ai_pipeline_core/prompt_builder/system_prompt.jinja2 +41 -0
- ai_pipeline_core/tracing.py +1 -1
- ai_pipeline_core/utils/remote_deployment.py +37 -187
- {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/METADATA +23 -20
- ai_pipeline_core-0.3.0.dist-info/RECORD +49 -0
- {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/WHEEL +1 -1
- ai_pipeline_core/simple_runner/__init__.py +0 -14
- ai_pipeline_core/simple_runner/cli.py +0 -254
- ai_pipeline_core/simple_runner/simple_runner.py +0 -247
- ai_pipeline_core-0.2.8.dist-info/RECORD +0 -41
- {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,681 @@
|
|
|
1
|
+
"""Core classes for pipeline deployments.
|
|
2
|
+
|
|
3
|
+
@public
|
|
4
|
+
|
|
5
|
+
Provides the PipelineDeployment base class and related types for
|
|
6
|
+
creating unified, type-safe pipeline deployments with:
|
|
7
|
+
- Per-flow caching (skip if outputs exist)
|
|
8
|
+
- Per-flow uploads (immediate, not just at end)
|
|
9
|
+
- Prefect state hooks (on_running, on_completion, etc.)
|
|
10
|
+
- Smart storage provisioning (override provision_storage)
|
|
11
|
+
- Upload on failure (partial results saved)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
import sys
|
|
18
|
+
from abc import abstractmethod
|
|
19
|
+
from contextlib import ExitStack
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from datetime import datetime, timedelta, timezone
|
|
22
|
+
from hashlib import sha256
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any, Callable, ClassVar, Generic, Protocol, TypeVar, cast, final
|
|
25
|
+
from uuid import UUID
|
|
26
|
+
|
|
27
|
+
import httpx
|
|
28
|
+
from lmnr import Laminar
|
|
29
|
+
from prefect import get_client
|
|
30
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
31
|
+
from pydantic_settings import CliPositionalArg, SettingsConfigDict
|
|
32
|
+
|
|
33
|
+
from ai_pipeline_core.documents import DocumentList
|
|
34
|
+
from ai_pipeline_core.flow.options import FlowOptions
|
|
35
|
+
from ai_pipeline_core.logging import get_pipeline_logger, setup_logging
|
|
36
|
+
from ai_pipeline_core.prefect import disable_run_logger, flow, prefect_test_harness
|
|
37
|
+
from ai_pipeline_core.settings import settings
|
|
38
|
+
|
|
39
|
+
from .contract import CompletedRun, DeploymentResultData, FailedRun, ProgressRun
|
|
40
|
+
from .helpers import (
|
|
41
|
+
StatusPayload,
|
|
42
|
+
class_name_to_deployment_name,
|
|
43
|
+
download_documents,
|
|
44
|
+
extract_generic_params,
|
|
45
|
+
send_webhook,
|
|
46
|
+
upload_documents,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
# Module-level logger used for the cache/progress messages and best-effort webhook warnings below.
logger = get_pipeline_logger(__name__)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class DeploymentContext(BaseModel):
    """@public Infrastructure configuration for deployments.

    Webhooks are optional - provide URLs to enable:
    - progress_webhook_url: Per-flow progress (started/completed/cached)
    - status_webhook_url: Prefect state transitions (RUNNING/FAILED/etc)
    - completion_webhook_url: Final result when deployment ends
    """

    # Immutable once built; unknown keys are rejected.
    model_config = ConfigDict(frozen=True, extra="forbid")

    # Extra input documents downloaded before the first flow runs.
    input_documents_urls: list[str] = Field(default_factory=list)
    # Upload targets passed to upload_documents after each flow.
    output_documents_urls: dict[str, str] = Field(default_factory=dict)

    # An empty string leaves the corresponding webhook disabled.
    progress_webhook_url: str = ""
    status_webhook_url: str = ""
    completion_webhook_url: str = ""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class DeploymentResult(BaseModel):
    """@public Base class for deployment results."""

    # Results are immutable and reject undeclared fields; subclasses inherit this.
    model_config = ConfigDict(frozen=True, extra="forbid")

    # Required flag; there is no default.
    success: bool
    # Optional error message, None unless a concrete result sets it.
    error: str | None = None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Options model accepted by a concrete deployment; must be a FlowOptions subclass.
TOptions = TypeVar("TOptions", bound=FlowOptions)
|
|
81
|
+
# Result model produced by a concrete deployment's build_result; must subclass DeploymentResult.
TResult = TypeVar("TResult", bound=DeploymentResult)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class FlowCallable(Protocol):
    """Structural type describing a function decorated with @pipeline_flow.

    Anything matching this protocol exposes ``config``, ``name`` and
    ``__name__``, is callable with ``(project_name, documents, flow_options)``
    and can produce a re-configured copy through ``with_options``.
    """

    name: str
    __name__: str
    config: Any

    def with_options(self, **kwargs: Any) -> "FlowCallable":
        """Give back a clone of this flow with the given Prefect options applied."""
        ...

    def __call__(self, project_name: str, documents: DocumentList, flow_options: FlowOptions) -> Any: ...
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass(slots=True)
|
|
101
|
+
class _StatusWebhookHook:
|
|
102
|
+
"""Prefect hook that sends status webhooks on state transitions."""
|
|
103
|
+
|
|
104
|
+
webhook_url: str
|
|
105
|
+
flow_run_id: str
|
|
106
|
+
project_name: str
|
|
107
|
+
step: int
|
|
108
|
+
total_steps: int
|
|
109
|
+
flow_name: str
|
|
110
|
+
|
|
111
|
+
async def __call__(self, flow: Any, flow_run: Any, state: Any) -> None:
|
|
112
|
+
payload: StatusPayload = {
|
|
113
|
+
"type": "status",
|
|
114
|
+
"flow_run_id": str(flow_run.id),
|
|
115
|
+
"project_name": self.project_name,
|
|
116
|
+
"step": self.step,
|
|
117
|
+
"total_steps": self.total_steps,
|
|
118
|
+
"flow_name": self.flow_name,
|
|
119
|
+
"state": state.type.value if hasattr(state.type, "value") else str(state.type),
|
|
120
|
+
"state_name": state.name or "",
|
|
121
|
+
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
122
|
+
}
|
|
123
|
+
try:
|
|
124
|
+
async with httpx.AsyncClient(timeout=10) as client:
|
|
125
|
+
await client.post(self.webhook_url, json=payload)
|
|
126
|
+
except Exception as e:
|
|
127
|
+
logger.warning(f"Status webhook failed: {e}")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class PipelineDeployment(Generic[TOptions, TResult]):
    """@public Base class for pipeline deployments.

    Features enabled by default when URLs/storage provided:
    - Per-flow caching: Skip flows if outputs exist in storage
    - Per-flow uploads: Upload documents after each flow (cached outputs included)
    - Prefect hooks: Attach state hooks if status_webhook_url provided
    - Upload on failure: Save partial results if pipeline fails

    Subclasses declare ``flows`` and parameterize the base as
    ``PipelineDeployment[MyOptions, MyResult]``; ``name``, ``options_type`` and
    ``result_type`` are then filled in by ``__init_subclass__``.
    """

    flows: ClassVar[list[FlowCallable]]
    name: ClassVar[str]
    options_type: ClassVar[type[FlowOptions]]
    result_type: ClassVar[type[DeploymentResult]]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate a concrete deployment subclass and derive its metadata.

        Subclasses that do not define ``flows`` are left untouched, so
        intermediate base classes can be declared.

        Raises:
            TypeError: If the class name starts with "Test", the generic
                parameters are not specified, or ``flows`` is empty.
        """
        super().__init_subclass__(**kwargs)

        if not hasattr(cls, "flows"):
            return

        if cls.__name__.startswith("Test"):
            raise TypeError(f"Deployment class name cannot start with 'Test': {cls.__name__}")

        cls.name = class_name_to_deployment_name(cls.__name__)

        options_type, result_type = extract_generic_params(cls)
        if options_type is None or result_type is None:
            raise TypeError(
                f"{cls.__name__} must specify Generic parameters: "
                f"class {cls.__name__}(PipelineDeployment[MyOptions, MyResult])"
            )

        cls.options_type = options_type
        cls.result_type = result_type

        if not cls.flows:
            raise TypeError(f"{cls.__name__}.flows cannot be empty")

    @staticmethod
    @abstractmethod
    def build_result(project_name: str, documents: DocumentList, options: TOptions) -> TResult:
        """Extract typed result from accumulated pipeline documents."""
        ...

    async def provision_storage(
        self,
        project_name: str,
        documents: DocumentList,
        options: TOptions,
        context: DeploymentContext,
    ) -> str:
        """Provision GCS storage bucket based on project name and content hash.

        Default: Creates `{project}-{date}-{hash}` bucket on GCS.
        Returns empty string if GCS is unavailable or creation fails.
        Override for custom storage provisioning logic.

        A non-empty bucket with the same content hash dated today or yesterday
        (UTC) is reused before a new one is created.
        """
        if not documents:
            return ""

        try:
            from ai_pipeline_core.storage.storage import GcsStorage  # noqa: PLC0415
        except ImportError:
            return ""

        # Sorting makes the fingerprint independent of document order.
        content_hash = sha256(b"".join(sorted(d.content for d in documents))).hexdigest()[:6]
        base = re.sub(r"[^a-z0-9-]", "-", project_name.lower()).strip("-") or "project"
        # One clock reading keeps "today" and "yesterday" consistent around midnight.
        now = datetime.now(timezone.utc)
        today = now.strftime("%y-%m-%d")
        yesterday = (now - timedelta(days=1)).strftime("%y-%m-%d")

        today_bucket = f"{base[:30]}-{today}-{content_hash}"
        yesterday_bucket = f"{base[:30]}-{yesterday}-{content_hash}"

        # Try today's bucket, then yesterday's, then create new
        for bucket_name in (today_bucket, yesterday_bucket):
            try:
                storage = GcsStorage(bucket_name)
                if await storage.list(recursive=False):
                    logger.info(f"Using existing bucket: {bucket_name}")
                    return f"gs://{bucket_name}"
            except Exception:
                # Any error while probing a candidate simply moves on to the next one.
                continue

        try:
            storage = GcsStorage(today_bucket)
            await storage.create_bucket()
            logger.info(f"Created new bucket: {today_bucket}")
            return f"gs://{today_bucket}"
        except Exception as e:
            logger.warning(f"Failed to provision GCS storage: {e}")
            return ""

    async def _load_cached_output(
        self, flow_fn: FlowCallable, storage_uri: str
    ) -> DocumentList | None:
        """Load cached outputs if they exist. Override for custom cache logic.

        Returns None when no documents of the flow's OUTPUT_DOCUMENT_TYPE are
        found or when loading raises.
        """
        try:
            output_type = flow_fn.config.OUTPUT_DOCUMENT_TYPE
            docs = await flow_fn.config.load_documents_by_type(storage_uri, [output_type])
            return docs if docs else None
        except Exception:
            return None

    @staticmethod
    def _current_flow_run_id() -> str:
        """Return the current Prefect flow run id, or "" when not inside a flow run.

        ``prefect.runtime.flow_run`` is a module (always truthy), and its
        ``get_id()`` returns None outside a run; that None must not be
        stringified into the bogus id "None".
        """
        from prefect import runtime  # noqa: PLC0415

        flow_run_module = getattr(runtime, "flow_run", None)
        run_id = flow_run_module.get_id() if flow_run_module is not None else None  # pyright: ignore[reportAttributeAccessIssue]
        return str(run_id) if run_id else ""

    def _build_status_hooks(
        self,
        context: DeploymentContext,
        flow_run_id: str,
        project_name: str,
        step: int,
        total_steps: int,
        flow_name: str,
    ) -> dict[str, list[Callable[..., Any]]]:
        """Build Prefect hooks for status webhooks.

        The same hook instance is registered for every state transition so
        each one produces a status webhook.
        """
        hook = _StatusWebhookHook(
            webhook_url=context.status_webhook_url,
            flow_run_id=flow_run_id,
            project_name=project_name,
            step=step,
            total_steps=total_steps,
            flow_name=flow_name,
        )
        return {
            "on_running": [hook],
            "on_completion": [hook],
            "on_failure": [hook],
            "on_crashed": [hook],
            "on_cancellation": [hook],
        }

    async def _send_progress(
        self,
        context: DeploymentContext,
        flow_run_id: str,
        project_name: str,
        storage_uri: str,
        step: int,
        total_steps: int,
        flow_name: str,
        status: str,
        step_progress: float = 0.0,
        message: str = "",
    ) -> None:
        """Send progress webhook and update flow run labels.

        Overall progress is ``(step - 1 + step_progress) / total_steps``.
        Both the webhook and the label update are best-effort: failures are
        logged and never raised.
        """
        progress = round((step - 1 + step_progress) / total_steps, 4)
        rounded_step_progress = round(step_progress, 4)

        if context.progress_webhook_url:
            try:
                # Payload construction is inside the try so a bad field never aborts the run.
                payload = ProgressRun(
                    flow_run_id=UUID(flow_run_id) if flow_run_id else UUID(int=0),
                    project_name=project_name,
                    state="RUNNING",
                    timestamp=datetime.now(timezone.utc),
                    storage_uri=storage_uri,
                    step=step,
                    total_steps=total_steps,
                    flow_name=flow_name,
                    status=status,
                    progress=progress,
                    step_progress=rounded_step_progress,
                    message=message,
                )
                await send_webhook(context.progress_webhook_url, payload)
            except Exception as e:
                logger.warning(f"Progress webhook failed: {e}")

        if flow_run_id:
            try:
                async with get_client() as client:
                    await client.update_flow_run_labels(
                        flow_run_id=UUID(flow_run_id),
                        labels={
                            "progress.step": step,
                            "progress.total_steps": total_steps,
                            "progress.flow_name": flow_name,
                            "progress.status": status,
                            "progress.progress": progress,
                            "progress.step_progress": rounded_step_progress,
                            "progress.message": message,
                        },
                    )
            except Exception as e:
                logger.warning(f"Progress label update failed: {e}")

    async def _send_completion(
        self,
        context: DeploymentContext,
        flow_run_id: str,
        project_name: str,
        storage_uri: str,
        result: TResult | None,
        error: str | None,
    ) -> None:
        """Send completion webhook.

        A non-None ``result`` produces a COMPLETED payload; otherwise a FAILED
        payload carrying ``error``. Failures to send are logged, never raised.
        """
        if not context.completion_webhook_url:
            return
        try:
            now = datetime.now(timezone.utc)
            frid = UUID(flow_run_id) if flow_run_id else UUID(int=0)
            payload: CompletedRun | FailedRun
            if result is not None:
                payload = CompletedRun(
                    flow_run_id=frid,
                    project_name=project_name,
                    timestamp=now,
                    storage_uri=storage_uri,
                    state="COMPLETED",
                    result=DeploymentResultData.model_validate(result.model_dump()),
                )
            else:
                payload = FailedRun(
                    flow_run_id=frid,
                    project_name=project_name,
                    timestamp=now,
                    storage_uri=storage_uri,
                    state="FAILED",
                    error=error or "Unknown error",
                )
            await send_webhook(context.completion_webhook_url, payload)
        except Exception as e:
            logger.warning(f"Completion webhook failed: {e}")

    async def _resolve_storage(
        self,
        project_name: str,
        documents: str | DocumentList,
        options: TOptions,
        context: DeploymentContext,
    ) -> tuple[str, DocumentList]:
        """Return ``(storage_uri, input_documents)`` for a run.

        A string is treated as an existing storage URI and its documents are
        loaded with the first flow's config. In-memory documents trigger
        ``provision_storage`` and, if storage was obtained, are saved there.
        """
        if isinstance(documents, str):
            return documents, await self.flows[0].config.load_documents(documents)

        storage_uri = await self.provision_storage(project_name, documents, options, context)
        if storage_uri and documents:
            await self.flows[0].config.save_documents(
                storage_uri, documents, validate_output_type=False
            )
        return storage_uri, documents

    async def _write_identity_labels(
        self, flow_run_id: str, project_name: str, storage_uri: str
    ) -> None:
        """Write identity labels for polling endpoint (best-effort)."""
        if not flow_run_id:
            return
        try:
            async with get_client() as client:
                await client.update_flow_run_labels(
                    flow_run_id=UUID(flow_run_id),
                    labels={
                        "pipeline.project_name": project_name,
                        "pipeline.storage_uri": storage_uri,
                    },
                )
        except Exception as e:
            logger.warning(f"Identity label update failed: {e}")

    async def _download_inputs(
        self, docs: DocumentList, storage_uri: str, context: DeploymentContext
    ) -> DocumentList:
        """Append documents from ``context.input_documents_urls`` to ``docs``.

        Downloads are typed as the first flow's first INPUT_DOCUMENT_TYPES
        entry. When storage is in use they are also saved there, because
        flows read their inputs back from storage.
        """
        if not context.input_documents_urls:
            return docs

        first_config = self.flows[0].config
        first_input_type = first_config.INPUT_DOCUMENT_TYPES[0]
        downloaded = DocumentList(
            list(await download_documents(context.input_documents_urls, first_input_type))
        )
        if storage_uri and downloaded:
            await first_config.save_documents(storage_uri, downloaded, validate_output_type=False)
        return DocumentList(list(docs) + list(downloaded))

    async def _upload_partial_results(
        self, docs: DocumentList, context: DeploymentContext
    ) -> None:
        """Best-effort upload of partial results after a flow failure.

        Errors are logged rather than raised so they cannot mask the flow's
        own exception.
        """
        if not context.output_documents_urls:
            return
        try:
            await upload_documents(docs, context.output_documents_urls)
        except Exception as e:
            logger.warning(f"Partial result upload failed: {e}")

    @final
    async def run(
        self,
        project_name: str,
        documents: str | DocumentList,
        options: TOptions,
        context: DeploymentContext,
    ) -> TResult:
        """Execute flows with caching, uploads, and webhooks enabled by default.

        Args:
            project_name: Project identifier forwarded to every flow and webhook.
            documents: A storage URI to load inputs from, or in-memory documents
                (storage is then provisioned via ``provision_storage``).
            options: Options passed to every flow and to ``build_result``.
            context: Webhook URLs and document input/output URLs.

        Returns:
            The result of ``build_result`` over all accumulated documents.

        Raises:
            Exception: Any error from storage setup, downloads, flows, uploads or
                ``build_result`` is re-raised after a FAILED completion
                webhook has been sent.
        """
        total_steps = len(self.flows)
        # All steps run inside the same parent flow run, so resolve its id once.
        flow_run_id = self._current_flow_run_id()
        # Known before any I/O so that setup failures can still be reported.
        storage_uri = documents if isinstance(documents, str) else ""
        completion_sent = False

        try:
            storage_uri, docs = await self._resolve_storage(project_name, documents, options, context)
            await self._write_identity_labels(flow_run_id, project_name, storage_uri)
            accumulated_docs = await self._download_inputs(docs, storage_uri, context)

            for step, flow_fn in enumerate(self.flows, start=1):
                flow_name = getattr(flow_fn, "name", flow_fn.__name__)

                # Per-flow caching: check if outputs exist
                if storage_uri:
                    cached = await self._load_cached_output(flow_fn, storage_uri)
                    if cached is not None:
                        logger.info(f"[{step}/{total_steps}] Cache hit: {flow_name}")
                        accumulated_docs = DocumentList(list(accumulated_docs) + list(cached))
                        # Cached outputs must reach the output URLs just like fresh ones.
                        if context.output_documents_urls:
                            await upload_documents(cached, context.output_documents_urls)
                        await self._send_progress(
                            context,
                            flow_run_id,
                            project_name,
                            storage_uri,
                            step,
                            total_steps,
                            flow_name,
                            "cached",
                            step_progress=1.0,
                            message=f"Loaded from cache: {flow_name}",
                        )
                        continue

                # Prefect state hooks
                active_flow = flow_fn
                if context.status_webhook_url:
                    hooks = self._build_status_hooks(
                        context, flow_run_id, project_name, step, total_steps, flow_name
                    )
                    active_flow = flow_fn.with_options(**hooks)

                # Progress: started
                await self._send_progress(
                    context,
                    flow_run_id,
                    project_name,
                    storage_uri,
                    step,
                    total_steps,
                    flow_name,
                    "started",
                    step_progress=0.0,
                    message=f"Starting: {flow_name}",
                )

                logger.info(f"[{step}/{total_steps}] Starting: {flow_name}")

                # With storage, inputs are read back through the flow's config; otherwise from memory.
                if storage_uri:
                    current_docs = await flow_fn.config.load_documents(storage_uri)
                else:
                    current_docs = accumulated_docs

                try:
                    new_docs = await active_flow(project_name, current_docs, options)
                except Exception as e:
                    # Upload partial results on failure
                    await self._upload_partial_results(accumulated_docs, context)
                    await self._send_completion(
                        context, flow_run_id, project_name, storage_uri, result=None, error=str(e)
                    )
                    completion_sent = True
                    raise

                # Save to storage
                if storage_uri:
                    await flow_fn.config.save_documents(storage_uri, new_docs)

                accumulated_docs = DocumentList(list(accumulated_docs) + list(new_docs))

                # Per-flow upload
                if context.output_documents_urls:
                    await upload_documents(new_docs, context.output_documents_urls)

                # Progress: completed
                await self._send_progress(
                    context,
                    flow_run_id,
                    project_name,
                    storage_uri,
                    step,
                    total_steps,
                    flow_name,
                    "completed",
                    step_progress=1.0,
                    message=f"Completed: {flow_name}",
                )

                logger.info(f"[{step}/{total_steps}] Completed: {flow_name}")

            result = self.build_result(project_name, accumulated_docs, options)
            await self._send_completion(
                context, flow_run_id, project_name, storage_uri, result=result, error=None
            )
            return result

        except Exception as e:
            if not completion_sent:
                await self._send_completion(
                    context, flow_run_id, project_name, storage_uri, result=None, error=str(e)
                )
            raise

    @final
    def run_local(
        self,
        project_name: str,
        documents: str | DocumentList,
        options: TOptions,
        context: DeploymentContext | None = None,
        output_dir: Path | None = None,
    ) -> TResult:
        """Run locally with Prefect test harness.

        When ``output_dir`` is given together with in-memory documents, the
        documents are first saved into ``output_dir``, which then serves as
        the run's storage. ``result.json`` is written to ``output_dir`` after
        the run.
        """
        if context is None:
            context = DeploymentContext()

        # If output_dir provided and documents is DocumentList, use output_dir as storage
        if output_dir and isinstance(documents, DocumentList):
            output_dir.mkdir(parents=True, exist_ok=True)
            # Persist the inputs first; otherwise they would be dropped when switching to the path.
            if documents:
                asyncio.run(
                    self.flows[0].config.save_documents(
                        str(output_dir), documents, validate_output_type=False
                    )
                )
            documents = str(output_dir)

        with prefect_test_harness():
            with disable_run_logger():
                result = asyncio.run(self.run(project_name, documents, options, context))

        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)
            (output_dir / "result.json").write_text(result.model_dump_json(indent=2))

        return result

    @final
    def run_cli(
        self,
        initializer: Callable[[TOptions], tuple[str, DocumentList]] | None = None,
        trace_name: str | None = None,
    ) -> None:
        """Execute pipeline from CLI arguments with --start/--end step control.

        Options are parsed from ``sys.argv`` into a subclass of
        ``options_type`` extended with a positional working directory and
        ``--project-name``, ``--start`` and ``--end``. The working directory is
        used as storage and receives ``result.json``.
        """
        if len(sys.argv) == 1:
            sys.argv.append("--help")

        setup_logging()
        try:
            Laminar.initialize()
            logger.info("LMNR tracing initialized.")
        except Exception as e:
            logger.warning(f"Failed to initialize LMNR: {e}")

        deployment = self

        class _CliOptions(
            deployment.options_type,
            cli_parse_args=True,
            cli_kebab_case=True,
            cli_exit_on_error=True,
            cli_prog_name=deployment.name,
            cli_use_class_docs_for_groups=True,
        ):
            working_directory: CliPositionalArg[Path]
            project_name: str | None = None
            start: int = 1
            end: int | None = None

            model_config = SettingsConfigDict(frozen=True, extra="ignore")

        opts = cast(TOptions, _CliOptions())  # type: ignore[reportCallIssue]

        wd: Path = getattr(opts, "working_directory")
        wd.mkdir(parents=True, exist_ok=True)

        project_name = getattr(opts, "project_name") or wd.name
        start_step = getattr(opts, "start", 1)
        end_step = getattr(opts, "end", None)

        # Initialize documents and save to working directory.
        # The initializer's project name is ignored; project_name above is used instead.
        if initializer and start_step == 1:
            _, documents = initializer(opts)
            if documents and self.flows:
                first_config = getattr(self.flows[0], "config", None)
                if first_config:
                    asyncio.run(
                        first_config.save_documents(str(wd), documents, validate_output_type=False)
                    )

        context = DeploymentContext()

        with ExitStack() as stack:
            if trace_name:
                stack.enter_context(
                    Laminar.start_as_current_span(
                        name=f"{trace_name}-{project_name}",
                        input=[opts.model_dump_json()],
                    )
                )

            under_pytest = "PYTEST_CURRENT_TEST" in os.environ or "pytest" in sys.modules
            if not settings.prefect_api_key and not under_pytest:
                stack.enter_context(prefect_test_harness())
                stack.enter_context(disable_run_logger())

            result = asyncio.run(
                self._run_with_steps(
                    project_name=project_name,
                    storage_uri=str(wd),
                    options=opts,
                    context=context,
                    start_step=start_step,
                    end_step=end_step,
                )
            )

        result_file = wd / "result.json"
        result_file.write_text(result.model_dump_json(indent=2))
        logger.info(f"Result saved to {result_file}")

    async def _run_with_steps(
        self,
        project_name: str,
        storage_uri: str,
        options: TOptions,
        context: DeploymentContext,
        start_step: int = 1,
        end_step: int | None = None,
    ) -> TResult:
        """Run pipeline with start/end step control for CLI resume support.

        Steps are 1-based and inclusive. Only outputs of the executed or
        cache-hit steps in the range are passed to ``build_result``.

        Raises:
            ValueError: If the range is not ``1 <= start_step <= end_step <= len(flows)``.
        """
        total_steps = len(self.flows)
        if end_step is None:
            end_step = total_steps

        # Without this check, start_step=0 would silently index flows[-1].
        if not 1 <= start_step <= end_step <= total_steps:
            raise ValueError(
                f"Invalid step range: start={start_step}, end={end_step} "
                f"(pipeline has {total_steps} steps)"
            )

        accumulated_docs = DocumentList([])

        for i in range(start_step - 1, end_step):
            step = i + 1
            flow_fn = self.flows[i]
            flow_name = getattr(flow_fn, "name", flow_fn.__name__)
            logger.info(f"--- [Step {step}/{total_steps}] {flow_name} ---")

            # Check cache
            cached = await self._load_cached_output(flow_fn, storage_uri)
            if cached is not None:
                logger.info(f"[{step}/{total_steps}] Cache hit: {flow_name}")
                accumulated_docs = DocumentList(list(accumulated_docs) + list(cached))
                continue

            current_docs = await flow_fn.config.load_documents(storage_uri)
            new_docs = await flow_fn(project_name, current_docs, options)
            await flow_fn.config.save_documents(storage_uri, new_docs)
            accumulated_docs = DocumentList(list(accumulated_docs) + list(new_docs))

        return self.build_result(project_name, accumulated_docs, options)

    @final
    def as_prefect_flow(self) -> Callable[..., Any]:
        """Generate Prefect flow for production deployment.

        The returned flow is named after the deployment, names its runs
        ``{name}-{project_name}``, persists its result as JSON and delegates
        to ``run``.
        """
        deployment = self

        @flow(  # pyright: ignore[reportUntypedFunctionDecorator]
            name=self.name,
            flow_run_name=f"{self.name}-{{project_name}}",
            persist_result=True,
            result_serializer="json",
        )
        async def _deployment_flow(
            project_name: str,
            documents: str | DocumentList,
            options: FlowOptions,
            context: DeploymentContext,
        ) -> DeploymentResult:
            return await deployment.run(project_name, documents, cast(Any, options), context)

        return _deployment_flow
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
# Names exported by ``from ... import *``: the public deployment API of this module.
__all__ = ["DeploymentContext", "DeploymentResult", "PipelineDeployment"]
|