ai-pipeline-core 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_pipeline_core/__init__.py +64 -158
- ai_pipeline_core/deployment/__init__.py +6 -18
- ai_pipeline_core/deployment/base.py +392 -212
- ai_pipeline_core/deployment/contract.py +6 -10
- ai_pipeline_core/{utils → deployment}/deploy.py +50 -69
- ai_pipeline_core/deployment/helpers.py +16 -17
- ai_pipeline_core/{progress.py → deployment/progress.py} +23 -24
- ai_pipeline_core/{utils/remote_deployment.py → deployment/remote.py} +11 -14
- ai_pipeline_core/docs_generator/__init__.py +54 -0
- ai_pipeline_core/docs_generator/__main__.py +5 -0
- ai_pipeline_core/docs_generator/cli.py +196 -0
- ai_pipeline_core/docs_generator/extractor.py +324 -0
- ai_pipeline_core/docs_generator/guide_builder.py +644 -0
- ai_pipeline_core/docs_generator/trimmer.py +35 -0
- ai_pipeline_core/docs_generator/validator.py +114 -0
- ai_pipeline_core/document_store/__init__.py +13 -0
- ai_pipeline_core/document_store/_summary.py +9 -0
- ai_pipeline_core/document_store/_summary_worker.py +170 -0
- ai_pipeline_core/document_store/clickhouse.py +492 -0
- ai_pipeline_core/document_store/factory.py +38 -0
- ai_pipeline_core/document_store/local.py +312 -0
- ai_pipeline_core/document_store/memory.py +85 -0
- ai_pipeline_core/document_store/protocol.py +68 -0
- ai_pipeline_core/documents/__init__.py +12 -14
- ai_pipeline_core/documents/_context_vars.py +85 -0
- ai_pipeline_core/documents/_hashing.py +52 -0
- ai_pipeline_core/documents/attachment.py +85 -0
- ai_pipeline_core/documents/context.py +128 -0
- ai_pipeline_core/documents/document.py +318 -1434
- ai_pipeline_core/documents/mime_type.py +11 -84
- ai_pipeline_core/documents/utils.py +4 -12
- ai_pipeline_core/exceptions.py +10 -62
- ai_pipeline_core/images/__init__.py +32 -85
- ai_pipeline_core/images/_processing.py +5 -11
- ai_pipeline_core/llm/__init__.py +6 -4
- ai_pipeline_core/llm/ai_messages.py +102 -90
- ai_pipeline_core/llm/client.py +229 -183
- ai_pipeline_core/llm/model_options.py +12 -84
- ai_pipeline_core/llm/model_response.py +53 -99
- ai_pipeline_core/llm/model_types.py +8 -23
- ai_pipeline_core/logging/__init__.py +2 -7
- ai_pipeline_core/logging/logging.yml +1 -1
- ai_pipeline_core/logging/logging_config.py +27 -37
- ai_pipeline_core/logging/logging_mixin.py +15 -41
- ai_pipeline_core/observability/__init__.py +32 -0
- ai_pipeline_core/observability/_debug/__init__.py +30 -0
- ai_pipeline_core/observability/_debug/_auto_summary.py +94 -0
- ai_pipeline_core/{debug/config.py → observability/_debug/_config.py} +11 -7
- ai_pipeline_core/{debug/content.py → observability/_debug/_content.py} +133 -75
- ai_pipeline_core/{debug/processor.py → observability/_debug/_processor.py} +16 -17
- ai_pipeline_core/{debug/summary.py → observability/_debug/_summary.py} +113 -37
- ai_pipeline_core/observability/_debug/_types.py +75 -0
- ai_pipeline_core/{debug/writer.py → observability/_debug/_writer.py} +126 -196
- ai_pipeline_core/observability/_document_tracking.py +146 -0
- ai_pipeline_core/observability/_initialization.py +194 -0
- ai_pipeline_core/observability/_logging_bridge.py +57 -0
- ai_pipeline_core/observability/_summary.py +81 -0
- ai_pipeline_core/observability/_tracking/__init__.py +6 -0
- ai_pipeline_core/observability/_tracking/_client.py +178 -0
- ai_pipeline_core/observability/_tracking/_internal.py +28 -0
- ai_pipeline_core/observability/_tracking/_models.py +138 -0
- ai_pipeline_core/observability/_tracking/_processor.py +158 -0
- ai_pipeline_core/observability/_tracking/_service.py +311 -0
- ai_pipeline_core/observability/_tracking/_writer.py +229 -0
- ai_pipeline_core/{tracing.py → observability/tracing.py} +139 -335
- ai_pipeline_core/pipeline/__init__.py +10 -0
- ai_pipeline_core/pipeline/decorators.py +915 -0
- ai_pipeline_core/pipeline/options.py +16 -0
- ai_pipeline_core/prompt_manager.py +16 -102
- ai_pipeline_core/settings.py +26 -31
- ai_pipeline_core/testing.py +9 -0
- ai_pipeline_core-0.4.0.dist-info/METADATA +807 -0
- ai_pipeline_core-0.4.0.dist-info/RECORD +76 -0
- ai_pipeline_core/debug/__init__.py +0 -26
- ai_pipeline_core/documents/document_list.py +0 -420
- ai_pipeline_core/documents/flow_document.py +0 -112
- ai_pipeline_core/documents/task_document.py +0 -117
- ai_pipeline_core/documents/temporary_document.py +0 -74
- ai_pipeline_core/flow/__init__.py +0 -9
- ai_pipeline_core/flow/config.py +0 -494
- ai_pipeline_core/flow/options.py +0 -75
- ai_pipeline_core/pipeline.py +0 -718
- ai_pipeline_core/prefect.py +0 -63
- ai_pipeline_core/prompt_builder/__init__.py +0 -5
- ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +0 -23
- ai_pipeline_core/prompt_builder/global_cache.py +0 -78
- ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +0 -6
- ai_pipeline_core/prompt_builder/prompt_builder.py +0 -253
- ai_pipeline_core/prompt_builder/system_prompt.jinja2 +0 -41
- ai_pipeline_core/storage/__init__.py +0 -8
- ai_pipeline_core/storage/storage.py +0 -628
- ai_pipeline_core/utils/__init__.py +0 -8
- ai_pipeline_core-0.3.4.dist-info/METADATA +0 -569
- ai_pipeline_core-0.3.4.dist-info/RECORD +0 -57
- {ai_pipeline_core-0.3.4.dist-info → ai_pipeline_core-0.4.0.dist-info}/WHEEL +0 -0
- {ai_pipeline_core-0.3.4.dist-info → ai_pipeline_core-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,915 @@
|
|
|
1
|
+
"""Pipeline decorators with Prefect integration, tracing, and document lifecycle.
|
|
2
|
+
|
|
3
|
+
Wrappers around Prefect's @task and @flow that add Laminar tracing,
|
|
4
|
+
enforce async-only execution, and auto-save documents to the DocumentStore.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import datetime
|
|
8
|
+
import inspect
|
|
9
|
+
import types
|
|
10
|
+
from collections.abc import Callable, Coroutine, Iterable
|
|
11
|
+
from functools import wraps
|
|
12
|
+
from typing import (
|
|
13
|
+
Any,
|
|
14
|
+
Protocol,
|
|
15
|
+
TypeVar,
|
|
16
|
+
Union,
|
|
17
|
+
cast,
|
|
18
|
+
get_args,
|
|
19
|
+
get_origin,
|
|
20
|
+
get_type_hints,
|
|
21
|
+
overload,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from lmnr import Laminar
|
|
25
|
+
from prefect.assets import Asset
|
|
26
|
+
from prefect.cache_policies import CachePolicy
|
|
27
|
+
from prefect.context import TaskRunContext
|
|
28
|
+
from prefect.flows import FlowStateHook
|
|
29
|
+
from prefect.flows import flow as _prefect_flow
|
|
30
|
+
from prefect.futures import PrefectFuture
|
|
31
|
+
from prefect.results import ResultSerializer, ResultStorage
|
|
32
|
+
from prefect.task_runners import TaskRunner
|
|
33
|
+
from prefect.tasks import task as _prefect_task
|
|
34
|
+
from prefect.utilities.annotations import NotSet
|
|
35
|
+
from pydantic import BaseModel
|
|
36
|
+
|
|
37
|
+
from ai_pipeline_core.document_store import get_document_store
|
|
38
|
+
from ai_pipeline_core.documents import Document
|
|
39
|
+
from ai_pipeline_core.documents.context import (
|
|
40
|
+
RunContext,
|
|
41
|
+
TaskDocumentContext,
|
|
42
|
+
get_run_context,
|
|
43
|
+
reset_run_context,
|
|
44
|
+
reset_task_context,
|
|
45
|
+
set_run_context,
|
|
46
|
+
set_task_context,
|
|
47
|
+
)
|
|
48
|
+
from ai_pipeline_core.documents.utils import is_document_sha256
|
|
49
|
+
from ai_pipeline_core.logging import get_pipeline_logger
|
|
50
|
+
from ai_pipeline_core.observability._document_tracking import get_current_span_id, track_flow_io, track_task_io
|
|
51
|
+
from ai_pipeline_core.observability._initialization import get_tracking_service
|
|
52
|
+
from ai_pipeline_core.observability._tracking._models import DocumentEventType
|
|
53
|
+
from ai_pipeline_core.observability.tracing import TraceLevel, set_trace_cost, trace
|
|
54
|
+
from ai_pipeline_core.pipeline.options import FlowOptions
|
|
55
|
+
|
|
56
|
+
# Module logger; used below for provenance / persistence / type-hint warnings.
logger = get_pipeline_logger(__name__)

# --------------------------------------------------------------------------- #
# Public callback aliases (Prefect stubs omit these exact types)
# --------------------------------------------------------------------------- #
# Retry predicate taking three positional arguments and returning whether to retry.
# NOTE(review): the three arguments are presumably (task, task_run, state) as in
# Prefect's retry_condition_fn — confirm against the Prefect version in use.
type RetryConditionCallable = Callable[[Any, Any, Any], bool]
# State hook (on_completion / on_failure) taking three positional arguments, returning nothing.
type StateHookCallable = Callable[[Any, Any, Any], None]
# Task run name: either a fixed string or a zero-argument callable producing one.
type TaskRunNameValueOrCallable = str | Callable[[], str]

# --------------------------------------------------------------------------- #
# Typing helpers
# --------------------------------------------------------------------------- #
# Covariant result type of a decorated task's coroutine (see _TaskLike.__call__).
R_co = TypeVar("R_co", covariant=True)
# Contravariant FlowOptions subtype accepted as ``flow_options`` by a decorated flow.
FO_contra = TypeVar("FO_contra", bound=FlowOptions, contravariant=True)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class _TaskLike(Protocol[R_co]):
    """Structural type describing what ``@pipeline_task`` returns.

    Calling the object produces a coroutine that resolves to the task result
    (``R_co``). Besides the declared attributes, the catch-all ``__getattr__``
    declaration lets type checkers accept any other attribute access on the
    underlying Prefect task.
    """

    # Identification / progress metadata
    name: str | None
    estimated_minutes: int

    # Prefect scheduling entry points
    submit: Callable[..., Any]
    map: Callable[..., Any]

    def __call__(self, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, R_co]: ...

    def __getattr__(self, name: str) -> Any: ...
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class _FlowLike(Protocol[FO_contra]):
    """Structural type describing what ``@pipeline_flow`` returns.

    Invoking it with a project name, the input documents and a flow options
    object (a ``FlowOptions`` subtype) produces a coroutine that resolves to
    the output documents. Any attribute not declared here is still accepted by
    type checkers through the catch-all ``__getattr__`` declaration.
    """

    # Identification / progress metadata
    name: str | None
    estimated_minutes: int

    # Declared document contract of the flow
    input_document_types: list[type[Document]]
    output_document_types: list[type[Document]]

    def __call__(
        self,
        project_name: str,
        documents: list[Document],
        flow_options: FO_contra,
    ) -> Coroutine[Any, Any, list[Document]]: ...

    def __getattr__(self, name: str) -> Any: ...
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# --------------------------------------------------------------------------- #
|
|
104
|
+
# Annotation parsing helpers
|
|
105
|
+
# --------------------------------------------------------------------------- #
|
|
106
|
+
def _flatten_union(tp: Any) -> list[Any]:
|
|
107
|
+
"""Flatten Union / X | Y annotations into a list of constituent types."""
|
|
108
|
+
origin = get_origin(tp)
|
|
109
|
+
if origin is Union or isinstance(tp, types.UnionType):
|
|
110
|
+
result: list[Any] = []
|
|
111
|
+
for arg in get_args(tp):
|
|
112
|
+
result.extend(_flatten_union(arg))
|
|
113
|
+
return result
|
|
114
|
+
return [tp]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _find_non_document_leaves(tp: Any) -> list[Any]:
    """Walk a return type annotation and collect leaf types that are not Document subclasses or NoneType.

    Returns an empty list when every leaf type is valid (a Document subclass or
    ``None``). Used by @pipeline_task to validate return annotations at
    decoration time.

    Recognised shapes:
        - ``X | Y`` / ``Union[X, Y]``: every branch must be valid.
        - ``list[X]``: the element type must be valid.
        - ``tuple[X, Y]`` / ``tuple[X, ...]``: every element type must be valid.
        - ``Annotated[X, ...]``: unwrapped to ``X``. Type hints are resolved
          with ``include_extras=True``, so the metadata wrapper survives
          resolution and would otherwise be reported as an invalid leaf.

    Anything else (``int``, ``str``, ``Any``, ``dict``, bare ``list``/``tuple``,
    ...) is returned as an invalid leaf.
    """
    from typing import Annotated  # only needed for the Annotated unwrap below

    if tp is type(None) or (isinstance(tp, type) and issubclass(tp, Document)):
        return []

    origin = get_origin(tp)

    # Annotated[X, meta]: only the underlying type matters for validation
    if origin is Annotated:
        return _find_non_document_leaves(get_args(tp)[0])

    # Union / X | Y: all branches must be valid
    if origin is Union or isinstance(tp, types.UnionType):
        return [leaf for arg in get_args(tp) for leaf in _find_non_document_leaves(arg)]

    # list[X]: recurse into element type
    if origin is list:
        args = get_args(tp)
        return _find_non_document_leaves(args[0]) if args else [tp]

    # tuple[X, Y] or tuple[X, ...]
    if origin is tuple:
        args = get_args(tp)
        if not args:
            return [tp]
        elements = (args[0],) if (len(args) == 2 and args[1] is Ellipsis) else args
        return [leaf for arg in elements for leaf in _find_non_document_leaves(arg)]

    # Everything else is invalid (int, str, Any, object, dict, etc.)
    return [tp]
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _parse_document_types_from_annotation(annotation: Any) -> list[type[Document]]:
    """Extract Document subclasses from a list[...] type annotation.

    Handles list[DocA], list[DocA | DocB] and list[Union[DocA, DocB]].
    ``Annotated[...]`` wrappers, whether around the whole list or around
    individual element types, are unwrapped first.

    Returns an empty list if the annotation is not a ``list[...]`` with an
    element type. Union members that are not Document subclasses are dropped
    from the result.
    """
    from typing import Annotated  # only needed for the Annotated unwrap below

    def _unwrap(tp: Any) -> Any:
        # Annotated[X, meta] -> X
        while get_origin(tp) is Annotated:
            tp = get_args(tp)[0]
        return tp

    annotation = _unwrap(annotation)
    if get_origin(annotation) is not list:
        return []

    args = get_args(annotation)
    if not args:
        return []

    members = (_unwrap(member) for member in _flatten_union(_unwrap(args[0])))
    return [t for t in members if isinstance(t, type) and issubclass(t, Document)]
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _resolve_type_hints(fn: Callable[..., Any]) -> dict[str, Any]:
    """Safely resolve type hints, falling back to an empty dict on failure.

    Hints are resolved with ``include_extras=True``, so ``Annotated`` metadata
    is preserved in the result.

    Args:
        fn: Function whose annotations should be resolved.

    Returns:
        Mapping of parameter names (and ``"return"``) to resolved types, or
        ``{}`` if resolution raised. A warning is logged in that case,
        including the underlying error so the offending annotation can be
        found.
    """
    try:
        return get_type_hints(fn, include_extras=True)
    except Exception as exc:
        logger.warning(
            "Failed to resolve type hints for '%s' (%s: %s). Ensure all annotations are valid and importable.",
            _callable_name(fn, "unknown"),
            type(exc).__name__,
            exc,
        )
        return {}
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# --------------------------------------------------------------------------- #
|
|
182
|
+
# Document extraction helper
|
|
183
|
+
# --------------------------------------------------------------------------- #
|
|
184
|
+
def _extract_documents(result: Any) -> list[Document]:
|
|
185
|
+
"""Recursively extract unique Document instances from a result value.
|
|
186
|
+
|
|
187
|
+
Walks tuples, lists, dicts, and Pydantic BaseModel fields.
|
|
188
|
+
Deduplicates by object identity (same instance appearing multiple times
|
|
189
|
+
is collected only once). Checks Document before BaseModel since Document
|
|
190
|
+
IS a BaseModel subclass.
|
|
191
|
+
"""
|
|
192
|
+
docs: list[Document] = []
|
|
193
|
+
seen: set[int] = set()
|
|
194
|
+
|
|
195
|
+
def _walk(value: Any) -> None:
|
|
196
|
+
obj_id = id(value)
|
|
197
|
+
if obj_id in seen:
|
|
198
|
+
return
|
|
199
|
+
seen.add(obj_id)
|
|
200
|
+
|
|
201
|
+
if isinstance(value, Document):
|
|
202
|
+
docs.append(value)
|
|
203
|
+
return
|
|
204
|
+
if isinstance(value, (list, tuple)):
|
|
205
|
+
for item in cast(Iterable[Any], value):
|
|
206
|
+
_walk(item)
|
|
207
|
+
return
|
|
208
|
+
if isinstance(value, dict):
|
|
209
|
+
for v in cast(Iterable[Any], value.values()):
|
|
210
|
+
_walk(v)
|
|
211
|
+
return
|
|
212
|
+
if isinstance(value, BaseModel):
|
|
213
|
+
for field_name in type(value).model_fields:
|
|
214
|
+
_walk(getattr(value, field_name))
|
|
215
|
+
return
|
|
216
|
+
|
|
217
|
+
_walk(result)
|
|
218
|
+
return docs
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# --------------------------------------------------------------------------- #
|
|
222
|
+
# Small helpers
|
|
223
|
+
# --------------------------------------------------------------------------- #
|
|
224
|
+
def _callable_name(obj: Any, fallback: str) -> str:
    """Return ``obj.__name__`` if it is a string, otherwise ``fallback``.

    Never raises. Any error while looking up the attribute (for example a
    property that raises something other than AttributeError) also yields
    ``fallback``.
    """
    try:
        candidate = getattr(obj, "__name__", None)
        if isinstance(candidate, str):
            return candidate
    except Exception:
        pass
    return fallback
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _is_already_traced(func: Callable[..., Any]) -> bool:
    """Report whether ``func``, or anything it wraps, carries a truthy ``__is_traced__``.

    The ``__wrapped__`` chain (as left by ``functools.wraps``) is followed for
    at most 10 hops beyond ``func`` itself, so a cyclic or pathological chain
    cannot loop forever.
    """
    node = func
    for hop in range(11):
        if hasattr(node, "__is_traced__") and node.__is_traced__:  # type: ignore[attr-defined]
            return True
        if hop == 10 or not hasattr(node, "__wrapped__"):
            return False
        node = node.__wrapped__  # type: ignore[attr-defined]
    return False
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# --------------------------------------------------------------------------- #
|
|
250
|
+
# Tracking helpers
|
|
251
|
+
# --------------------------------------------------------------------------- #
|
|
252
|
+
def _resolve_label(user_summary: str | bool, fn: Callable[..., Any], kwargs: dict[str, Any]) -> str:
|
|
253
|
+
"""Resolve user_summary to a label string."""
|
|
254
|
+
if isinstance(user_summary, str):
|
|
255
|
+
try:
|
|
256
|
+
return user_summary.format(**kwargs)
|
|
257
|
+
except (KeyError, IndexError):
|
|
258
|
+
return user_summary
|
|
259
|
+
return _callable_name(fn, "task").replace("_", " ").title()
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _build_output_hint(result: object) -> str:
    """Build a privacy-safe metadata string describing a task's output.

    Only metadata is reported: type names, counts, sizes, document names and
    MIME types. Document content is never included, only its length.

    Returns:
        ``"None"`` for None; per-class counts plus total size for a non-empty
        list made up entirely of Documents; name/size/MIME type for a single
        Document; an item count for any other list or tuple; a character count
        for strings; otherwise the type name of ``result``.
    """
    if result is None:
        return "None"
    if isinstance(result, list) and result:
        items = cast(list[Any], result)
        # Every element must be a Document. Checking only the first would reach
        # ``.content`` on a non-Document in a mixed list and raise.
        if all(isinstance(item, Document) for item in items):
            doc_list = cast(list[Document], items)
            class_counts: dict[str, int] = {}
            for doc in doc_list:
                cls_name = type(doc).__name__
                class_counts[cls_name] = class_counts.get(cls_name, 0) + 1
            total_size = sum(len(doc.content) for doc in doc_list)
            parts = [f"{name} x{count}" for name, count in class_counts.items()]
            return f"{len(doc_list)} documents ({', '.join(parts)}), total {total_size / 1024:.0f}KB"
    if isinstance(result, Document):
        size = len(result.content)
        return f"{type(result).__name__}(name={result.name!r}, {size / 1024:.0f}KB, {result.mime_type})"
    if isinstance(result, (list, tuple)):
        seq = cast(list[Any] | tuple[Any, ...], result)
        return f"{type(seq).__name__} with {len(seq)} items"
    if isinstance(result, str):
        return f"str ({len(result)} chars)"
    return type(result).__name__
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# --------------------------------------------------------------------------- #
|
|
288
|
+
# Store event emission helper
|
|
289
|
+
# --------------------------------------------------------------------------- #
|
|
290
|
+
def _emit_store_events(documents: list[Document], event_type: DocumentEventType) -> None:
    """Emit one store lifecycle event per document to the tracking service.

    This is best-effort. It does nothing when no tracking service is
    available. Any error raised while emitting is swallowed so that tracking
    can never break persistence, but it is logged at DEBUG level with a
    traceback so tracking bugs remain diagnosable.

    Args:
        documents: Documents to emit an event for (one event each).
        event_type: Lifecycle event to record, e.g. ``DocumentEventType.STORE_SAVED``.
    """
    try:
        service = get_tracking_service()
        if service is None:
            return
        span_id = get_current_span_id()
        for doc in documents:
            service.track_document_event(
                document_sha256=doc.sha256,
                span_id=span_id,
                event_type=event_type,
            )
    except Exception:
        logger.debug(
            "Failed to emit %s store events for %d document(s)",
            event_type,
            len(documents),
            exc_info=True,
        )
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# --------------------------------------------------------------------------- #
|
|
308
|
+
# Document persistence helper (used by @pipeline_task and @pipeline_flow)
|
|
309
|
+
# --------------------------------------------------------------------------- #
|
|
310
|
+
async def _persist_documents(
    documents: list[Document],
    label: str,
    ctx: "TaskDocumentContext",
    *,
    check_created: bool = False,
) -> None:
    """Validate provenance, deduplicate, and save documents to the store.

    Silently skips if no store or no run context is configured, or if
    ``documents`` is empty. Failures are logged as warnings instead of being
    raised (graceful degradation).

    Args:
        documents: Documents returned by the task or flow.
        label: Task or flow name used to prefix log messages.
        ctx: Document context used for provenance validation and finalization.
        check_created: When True, warn if a returned document was not created in this
            context. Enabled for @pipeline_task, disabled for @pipeline_flow (flows
            delegate creation to nested tasks).

    Events:
        STORE_SAVED is emitted once ``save_batch`` succeeds. STORE_SAVE_FAILED is
        emitted only when the failure happens before or during the save. A
        failure in post-save bookkeeping does not mark already-saved documents
        as failed.
    """
    run_ctx = get_run_context()
    store = get_document_store()
    if run_ctx is None or store is None or not documents:
        return

    deduped: list[Document] = []
    saved = False
    try:
        # Collect all SHA256 references (sources + origins) for existence check
        ref_sha256s: set[str] = set()
        for doc in documents:
            ref_sha256s.update(src for src in doc.sources if is_document_sha256(src))
            ref_sha256s.update(doc.origins)

        existing: set[str] = set()
        if ref_sha256s:
            existing = await store.check_existing(sorted(ref_sha256s))

        for warning in ctx.validate_provenance(documents, existing, check_created=check_created):
            logger.warning("[%s] %s", label, warning)

        # Deduplicate and save
        deduped = TaskDocumentContext.deduplicate(documents)
        await store.save_batch(deduped, run_ctx.run_scope)
        saved = True

        _emit_store_events(deduped, DocumentEventType.STORE_SAVED)

        # Finalize: warn about created-but-not-returned documents
        for warning in ctx.finalize(documents):
            logger.warning("[%s] %s", label, warning)
    except Exception:
        if saved:
            # Documents are already in the store; only the bookkeeping failed.
            logger.warning("Post-save bookkeeping failed for documents from '%s'", label, exc_info=True)
        else:
            _emit_store_events(deduped or documents, DocumentEventType.STORE_SAVE_FAILED)
            logger.warning("Failed to persist documents from '%s'", label, exc_info=True)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# --------------------------------------------------------------------------- #
|
|
370
|
+
# @pipeline_task — async-only, traced, auto-persists documents
|
|
371
|
+
# --------------------------------------------------------------------------- #
|
|
372
|
+
# Overload 1 — bare form: ``@pipeline_task`` applied directly to an async function.
# The function is passed positionally and the task-like wrapper is returned.
@overload
def pipeline_task(__fn: Callable[..., Coroutine[Any, Any, R_co]], /) -> _TaskLike[R_co]: ... # noqa: UP047
# Overload 2 — parameterized form: ``@pipeline_task(...)`` called with keyword-only
# options. It returns a decorator that turns an async function into the task-like wrapper.
@overload
def pipeline_task(
    *,
    # tracing
    trace_level: TraceLevel = "always",
    trace_ignore_input: bool = False,
    trace_ignore_output: bool = False,
    trace_ignore_inputs: list[str] | None = None,
    trace_input_formatter: Callable[..., str] | None = None,
    trace_output_formatter: Callable[..., str] | None = None,
    trace_cost: float | None = None,
    expected_cost: float | None = None,
    trace_trim_documents: bool = True,
    # tracking
    user_summary: bool | str = False,
    # document lifecycle
    estimated_minutes: int = 1,
    persist: bool = True,
    # prefect passthrough
    name: str | None = None,
    description: str | None = None,
    tags: Iterable[str] | None = None,
    version: str | None = None,
    cache_policy: CachePolicy | type[NotSet] = NotSet,
    cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
    cache_expiration: datetime.timedelta | None = None,
    task_run_name: TaskRunNameValueOrCallable | None = None,
    retries: int | None = None,
    retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
    retry_jitter_factor: float | None = None,
    persist_result: bool | None = None,
    result_storage: ResultStorage | str | None = None,
    result_serializer: ResultSerializer | str | None = None,
    result_storage_key: str | None = None,
    cache_result_in_memory: bool = True,
    timeout_seconds: int | float | None = None,
    log_prints: bool | None = False,
    refresh_cache: bool | None = None,
    on_completion: list[StateHookCallable] | None = None,
    on_failure: list[StateHookCallable] | None = None,
    retry_condition_fn: RetryConditionCallable | None = None,
    viz_return_value: bool | None = None,
    asset_deps: list[str | Asset] | None = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]: ...
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def pipeline_task( # noqa: UP047
|
|
421
|
+
__fn: Callable[..., Coroutine[Any, Any, R_co]] | None = None,
|
|
422
|
+
/,
|
|
423
|
+
*,
|
|
424
|
+
# tracing
|
|
425
|
+
trace_level: TraceLevel = "always",
|
|
426
|
+
trace_ignore_input: bool = False,
|
|
427
|
+
trace_ignore_output: bool = False,
|
|
428
|
+
trace_ignore_inputs: list[str] | None = None,
|
|
429
|
+
trace_input_formatter: Callable[..., str] | None = None,
|
|
430
|
+
trace_output_formatter: Callable[..., str] | None = None,
|
|
431
|
+
trace_cost: float | None = None,
|
|
432
|
+
expected_cost: float | None = None,
|
|
433
|
+
trace_trim_documents: bool = True,
|
|
434
|
+
# tracking
|
|
435
|
+
user_summary: bool | str = False,
|
|
436
|
+
# document lifecycle
|
|
437
|
+
estimated_minutes: int = 1,
|
|
438
|
+
persist: bool = True,
|
|
439
|
+
# prefect passthrough
|
|
440
|
+
name: str | None = None,
|
|
441
|
+
description: str | None = None,
|
|
442
|
+
tags: Iterable[str] | None = None,
|
|
443
|
+
version: str | None = None,
|
|
444
|
+
cache_policy: CachePolicy | type[NotSet] = NotSet,
|
|
445
|
+
cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
|
|
446
|
+
cache_expiration: datetime.timedelta | None = None,
|
|
447
|
+
task_run_name: TaskRunNameValueOrCallable | None = None,
|
|
448
|
+
retries: int | None = None,
|
|
449
|
+
retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
|
|
450
|
+
retry_jitter_factor: float | None = None,
|
|
451
|
+
persist_result: bool | None = None,
|
|
452
|
+
result_storage: ResultStorage | str | None = None,
|
|
453
|
+
result_serializer: ResultSerializer | str | None = None,
|
|
454
|
+
result_storage_key: str | None = None,
|
|
455
|
+
cache_result_in_memory: bool = True,
|
|
456
|
+
timeout_seconds: int | float | None = None,
|
|
457
|
+
log_prints: bool | None = False,
|
|
458
|
+
refresh_cache: bool | None = None,
|
|
459
|
+
on_completion: list[StateHookCallable] | None = None,
|
|
460
|
+
on_failure: list[StateHookCallable] | None = None,
|
|
461
|
+
retry_condition_fn: RetryConditionCallable | None = None,
|
|
462
|
+
viz_return_value: bool | None = None,
|
|
463
|
+
asset_deps: list[str | Asset] | None = None,
|
|
464
|
+
) -> _TaskLike[R_co] | Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]:
|
|
465
|
+
"""Decorate an async function as a traced Prefect task with document auto-save.
|
|
466
|
+
|
|
467
|
+
After the wrapped function returns, if documents are found in the result
|
|
468
|
+
and a DocumentStore + RunContext are available, documents are validated
|
|
469
|
+
for provenance, deduplicated by SHA256, and saved to the store.
|
|
470
|
+
|
|
471
|
+
When persist=True (default), the return type annotation is validated at
|
|
472
|
+
decoration time. Allowed return types::
|
|
473
|
+
|
|
474
|
+
-> MyDocument # single Document
|
|
475
|
+
-> list[DocA] / list[DocA | DocB] # list of Documents
|
|
476
|
+
-> tuple[DocA, DocB] # tuple of Documents
|
|
477
|
+
-> tuple[list[DocA], list[DocB]] # tuple of lists
|
|
478
|
+
-> tuple[DocA, ...] # variable-length tuple
|
|
479
|
+
-> None # side-effect tasks
|
|
480
|
+
-> DocA | None # optional Document
|
|
481
|
+
|
|
482
|
+
Use persist=False for tasks returning non-document values (tracing and
|
|
483
|
+
retries still apply, but no return type validation or document auto-save).
|
|
484
|
+
|
|
485
|
+
Args:
|
|
486
|
+
__fn: Function to decorate (when used without parentheses).
|
|
487
|
+
trace_level: When to trace ("always", "debug", "off").
|
|
488
|
+
trace_ignore_input: Don't trace input arguments.
|
|
489
|
+
trace_ignore_output: Don't trace return value.
|
|
490
|
+
trace_ignore_inputs: List of parameter names to exclude from tracing.
|
|
491
|
+
trace_input_formatter: Custom formatter for input tracing.
|
|
492
|
+
trace_output_formatter: Custom formatter for output tracing.
|
|
493
|
+
trace_cost: Optional cost value to track in metadata.
|
|
494
|
+
expected_cost: Optional expected cost budget for this task.
|
|
495
|
+
trace_trim_documents: Trim document content in traces (default True).
|
|
496
|
+
user_summary: Enable LLM-generated span summaries.
|
|
497
|
+
estimated_minutes: Estimated duration for progress tracking (must be > 0).
|
|
498
|
+
persist: Auto-save returned documents to the store (default True).
|
|
499
|
+
name: Task name (defaults to function name).
|
|
500
|
+
description: Human-readable task description.
|
|
501
|
+
tags: Tags for organization and filtering.
|
|
502
|
+
version: Task version string.
|
|
503
|
+
cache_policy: Caching policy for task results.
|
|
504
|
+
cache_key_fn: Custom cache key generation.
|
|
505
|
+
cache_expiration: How long to cache results.
|
|
506
|
+
task_run_name: Dynamic or static run name.
|
|
507
|
+
retries: Number of retry attempts (default 0).
|
|
508
|
+
retry_delay_seconds: Delay between retries.
|
|
509
|
+
retry_jitter_factor: Random jitter for retry delays.
|
|
510
|
+
persist_result: Whether to persist results.
|
|
511
|
+
result_storage: Where to store results.
|
|
512
|
+
result_serializer: How to serialize results.
|
|
513
|
+
result_storage_key: Custom storage key.
|
|
514
|
+
cache_result_in_memory: Keep results in memory.
|
|
515
|
+
timeout_seconds: Task execution timeout.
|
|
516
|
+
log_prints: Capture print() statements.
|
|
517
|
+
refresh_cache: Force cache refresh.
|
|
518
|
+
on_completion: Hooks for successful completion.
|
|
519
|
+
on_failure: Hooks for task failure.
|
|
520
|
+
retry_condition_fn: Custom retry condition.
|
|
521
|
+
viz_return_value: Include return value in visualization.
|
|
522
|
+
asset_deps: Upstream asset dependencies.
|
|
523
|
+
"""
|
|
524
|
+
if estimated_minutes < 1:
|
|
525
|
+
raise ValueError(f"estimated_minutes must be >= 1, got {estimated_minutes}")
|
|
526
|
+
|
|
527
|
+
task_decorator: Callable[..., Any] = _prefect_task
|
|
528
|
+
|
|
529
|
+
def _apply(fn: Callable[..., Coroutine[Any, Any, R_co]]) -> _TaskLike[R_co]:
    """Validate ``fn`` and wrap it as a traced Prefect task.

    Uses the decorator options captured from the enclosing ``pipeline_task`` call
    (``persist``, ``description``, ``trace_*``, Prefect passthrough options, ...).

    Args:
        fn: The user's ``async def`` task function.

    Returns:
        The Prefect task object produced by ``task_decorator``, with an
        ``estimated_minutes`` attribute attached.

    Raises:
        TypeError: If ``fn`` is not a coroutine function, is already wrapped with
            ``@trace``, references the removed ``DocumentList`` type in its
            annotations, or (when ``persist`` is enabled) lacks a return
            annotation or returns non-Document types.
    """
    import re  # local import: only needed for the DocumentList identifier check below

    fname = _callable_name(fn, "task")

    if not inspect.iscoroutinefunction(fn):
        raise TypeError(f"@pipeline_task target '{fname}' must be 'async def'")

    if _is_already_traced(fn):
        raise TypeError(
            f"@pipeline_task target '{fname}' is already decorated "
            f"with @trace. Remove the @trace decorator - @pipeline_task includes "
            f"tracing automatically."
        )

    # Reject stale DocumentList references in annotations. Match the name as a whole
    # identifier so unrelated types that merely contain the substring (for example
    # "ResearchDocumentList" or "DocumentListing") are not rejected.
    stale_name = re.compile(r"\bDocumentList\b")
    for ann_name, ann_value in getattr(fn, "__annotations__", {}).items():
        if stale_name.search(str(ann_value)):
            label = "return type" if ann_name == "return" else f"parameter '{ann_name}'"
            raise TypeError(f"@pipeline_task '{fname}' {label} references 'DocumentList' which has been removed. Use 'list[Document]' instead.")

    # Validate return type annotation when persist=True
    if persist:
        hints = _resolve_type_hints(fn)
        if "return" not in hints:
            raise TypeError(
                f"@pipeline_task '{fname}': missing return type annotation. "
                f"Persisted tasks must return Document types "
                f"(Document, list[Document], tuple[Document, ...], or None). "
                f"Add a return annotation or use persist=False."
            )
        bad_types = _find_non_document_leaves(hints["return"])
        if bad_types:
            bad_names = ", ".join(getattr(t, "__name__", str(t)) for t in bad_types)
            raise TypeError(
                f"@pipeline_task '{fname}': return type contains non-Document types: {bad_names}. "
                f"Persisted tasks must return Document, list[Document], "
                f"tuple[Document, ...], or None. "
                f"Use persist=False for tasks returning non-document values."
            )

    @wraps(fn)
    async def _wrapper(*args: Any, **kwargs: Any) -> R_co:
        # Attach optional metadata to the current span; tracing is best-effort.
        attrs: dict[str, Any] = {}
        if description:
            attrs["description"] = description
        if expected_cost is not None:
            attrs["expected_cost"] = expected_cost
        if attrs:
            try:
                Laminar.set_span_attributes(attrs)  # pyright: ignore[reportArgumentType]
            except Exception:
                pass

        # Set up TaskDocumentContext BEFORE calling fn() so Document.__init__ can register
        ctx: TaskDocumentContext | None = None
        task_token = None
        if persist and get_run_context() is not None and get_document_store() is not None:
            ctx = TaskDocumentContext()
            task_token = set_task_context(ctx)

        try:
            result = await fn(*args, **kwargs)
        finally:
            if task_token is not None:
                reset_task_context(task_token)

        if trace_cost is not None and trace_cost > 0:
            set_trace_cost(trace_cost)

        # Track task I/O and schedule summaries (best-effort: never fail the task)
        try:
            track_task_io(fname, args, kwargs, result)
        except Exception:
            pass

        if user_summary:
            try:
                service = get_tracking_service()
                if service is not None:
                    span_id = get_current_span_id()
                    if span_id:
                        label = _resolve_label(user_summary, fn, kwargs)
                        output_hint = _build_output_hint(result)
                        service.schedule_summary(span_id, label, output_hint)
            except Exception:
                pass

        # Document auto-save
        if persist and ctx is not None:
            await _persist_documents(_extract_documents(result), fname, ctx, check_created=True)

        return result

    traced_fn = trace(
        level=trace_level,
        name=name or fname,
        ignore_input=trace_ignore_input,
        ignore_output=trace_ignore_output,
        ignore_inputs=trace_ignore_inputs,
        input_formatter=trace_input_formatter,
        output_formatter=trace_output_formatter,
        trim_documents=trace_trim_documents,
    )(_wrapper)

    task_obj = cast(
        _TaskLike[R_co],
        task_decorator(
            name=name or fname,
            description=description,
            tags=tags,
            version=version,
            cache_policy=cache_policy,
            cache_key_fn=cache_key_fn,
            cache_expiration=cache_expiration,
            task_run_name=task_run_name or name or fname,
            retries=0 if retries is None else retries,
            retry_delay_seconds=retry_delay_seconds,
            retry_jitter_factor=retry_jitter_factor,
            persist_result=persist_result,
            result_storage=result_storage,
            result_serializer=result_serializer,
            result_storage_key=result_storage_key,
            cache_result_in_memory=cache_result_in_memory,
            timeout_seconds=timeout_seconds,
            log_prints=log_prints,
            refresh_cache=refresh_cache,
            on_completion=on_completion,
            on_failure=on_failure,
            retry_condition_fn=retry_condition_fn,
            viz_return_value=viz_return_value,
            asset_deps=asset_deps,
        )(traced_fn),
    )
    task_obj.estimated_minutes = estimated_minutes
    return task_obj
|
|
663
|
+
|
|
664
|
+
return _apply(__fn) if __fn else _apply
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
# --------------------------------------------------------------------------- #
|
|
668
|
+
# @pipeline_flow — async-only, traced, annotation-driven document types
|
|
669
|
+
# --------------------------------------------------------------------------- #
|
|
670
|
+
def pipeline_flow(
    *,
    # tracing
    trace_level: TraceLevel = "always",
    trace_ignore_input: bool = False,
    trace_ignore_output: bool = False,
    trace_ignore_inputs: list[str] | None = None,
    trace_input_formatter: Callable[..., str] | None = None,
    trace_output_formatter: Callable[..., str] | None = None,
    trace_cost: float | None = None,
    expected_cost: float | None = None,
    trace_trim_documents: bool = True,
    # tracking
    user_summary: bool | str = False,
    # progress tracking
    estimated_minutes: int = 1,
    # prefect passthrough
    name: str | None = None,
    version: str | None = None,
    flow_run_name: Callable[[], str] | str | None = None,
    retries: int | None = None,
    retry_delay_seconds: int | float | None = None,
    task_runner: TaskRunner[PrefectFuture[Any]] | None = None,
    description: str | None = None,
    timeout_seconds: int | float | None = None,
    validate_parameters: bool = True,
    persist_result: bool | None = None,
    result_storage: ResultStorage | str | None = None,
    result_serializer: ResultSerializer | str | None = None,
    cache_result_in_memory: bool = True,
    log_prints: bool | None = None,
    on_completion: list[FlowStateHook[Any, Any]] | None = None,
    on_failure: list[FlowStateHook[Any, Any]] | None = None,
    on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
    on_crashed: list[FlowStateHook[Any, Any]] | None = None,
    on_running: list[FlowStateHook[Any, Any]] | None = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, list[Document]]]], _FlowLike[Any]]:
    """Decorate an async function as a traced Prefect flow with annotation-driven document types.

    Extracts input/output document types from the function's type annotations
    at decoration time and attaches them as ``input_document_types`` and
    ``output_document_types`` attributes on the returned flow object.

    Required function signature::

        @pipeline_flow(estimated_minutes=30)
        async def my_flow(
            project_name: str,
            documents: list[DocA | DocB],
            flow_options: FlowOptions,
        ) -> list[OutputDoc]:
            ...

    Args:
        trace_cost: If > 0, recorded via ``set_trace_cost`` after the flow body returns.
        expected_cost: Attached to the current span as the ``expected_cost`` attribute.
        user_summary: Enable LLM-generated span summaries.
        estimated_minutes: Estimated duration for progress tracking (must be >= 1).
        name: Flow/span name; defaults to the function's name.
        description: Attached to the span and passed through to Prefect.
        Remaining ``trace_*`` options are forwarded to ``trace``; the remaining
        options are forwarded to the Prefect flow decorator.

    Returns:
        Decorator that produces a _FlowLike object with ``input_document_types``,
        ``output_document_types``, and ``estimated_minutes`` attributes.

    Raises:
        TypeError: If the function is not async, is already decorated with @trace,
            has wrong parameter count/types, has a return annotation that contains
            no Document subclasses, or output types overlap input types. A missing
            return annotation is accepted and yields empty ``output_document_types``.
            At call time, a TypeError is also raised if the flow does not return a list.
        ValueError: If estimated_minutes < 1.
    """
    if estimated_minutes < 1:
        raise ValueError(f"estimated_minutes must be >= 1, got {estimated_minutes}")

    flow_decorator: Callable[..., Any] = _prefect_flow

    def _apply(fn: Callable[..., Coroutine[Any, Any, list[Document]]]) -> _FlowLike[Any]:
        fname = _callable_name(fn, "flow")

        if not inspect.iscoroutinefunction(fn):
            raise TypeError(f"@pipeline_flow '{fname}' must be declared with 'async def'")

        if _is_already_traced(fn):
            raise TypeError(
                f"@pipeline_flow target '{fname}' is already decorated "
                f"with @trace. Remove the @trace decorator - @pipeline_flow includes "
                f"tracing automatically."
            )

        sig = inspect.signature(fn)
        params = list(sig.parameters.values())
        if len(params) != 3:
            raise TypeError(
                f"@pipeline_flow '{fname}' must have exactly 3 parameters "
                f"(project_name: str, documents: list[...], flow_options: FlowOptions), got {len(params)}"
            )

        # Resolve document types from annotations
        hints = _resolve_type_hints(fn)

        # Validate first parameter is str (only checked when annotated)
        if params[0].name in hints and hints[params[0].name] is not str:
            raise TypeError(f"@pipeline_flow '{fname}': first parameter '{params[0].name}' must be annotated as 'str'")

        # Validate third parameter is FlowOptions or subclass (only checked when annotated)
        if params[2].name in hints:
            p2_type = hints[params[2].name]
            if not (isinstance(p2_type, type) and issubclass(p2_type, FlowOptions)):
                raise TypeError(f"@pipeline_flow '{fname}': third parameter '{params[2].name}' must be FlowOptions or subclass, got {p2_type}")

        # Extract input types from documents parameter annotation
        resolved_input_types: list[type[Document]]
        if params[1].name in hints:
            resolved_input_types = _parse_document_types_from_annotation(hints[params[1].name])
        else:
            resolved_input_types = []

        # Extract output types from return annotation
        resolved_output_types: list[type[Document]]
        if "return" in hints:
            resolved_output_types = _parse_document_types_from_annotation(hints["return"])
        else:
            resolved_output_types = []

        # Validate return annotation contains Document subclasses
        if "return" in hints and not resolved_output_types:
            raise TypeError(
                f"@pipeline_flow '{fname}': return annotation does not contain "
                f"Document subclasses. Flows must return list[SomeDocument]. "
                f"Got: {hints['return']}."
            )

        # Output types must not overlap input types (skip for base Document used in generic flows)
        if resolved_output_types and resolved_input_types:
            overlap = (set(resolved_output_types) & set(resolved_input_types)) - {Document}
            if overlap:
                # Sorted so the error message is deterministic across runs.
                names = ", ".join(sorted(t.__name__ for t in overlap))
                raise TypeError(f"@pipeline_flow '{fname}': output types [{names}] cannot also be input types")

        @wraps(fn)
        async def _wrapper(
            project_name: str,
            documents: list[Document],
            flow_options: Any,
        ) -> list[Document]:
            attrs: dict[str, Any] = {}
            if description:
                attrs["description"] = description
            if expected_cost is not None:
                attrs["expected_cost"] = expected_cost
            if attrs:
                try:
                    Laminar.set_span_attributes(attrs)  # pyright: ignore[reportArgumentType]
                except Exception:
                    pass  # span attributes are best-effort; never fail the flow over them

            ctx: TaskDocumentContext | None = None
            run_token = None
            task_token = None
            # All context setup happens inside the try so that any token already
            # installed is reset even if a later setup step raises.
            try:
                # Set RunContext for nested tasks (only if not already set by deployment)
                if get_run_context() is None:
                    run_scope = f"{project_name}/{name or fname}"
                    run_token = set_run_context(RunContext(run_scope=run_scope))

                # Set up TaskDocumentContext for flow-level document lifecycle
                if get_run_context() is not None and get_document_store() is not None:
                    ctx = TaskDocumentContext()
                    task_token = set_task_context(ctx)

                result = await fn(project_name, documents, flow_options)
            finally:
                # Reset in reverse order of installation.
                if task_token is not None:
                    reset_task_context(task_token)
                if run_token is not None:
                    reset_run_context(run_token)

            if trace_cost is not None and trace_cost > 0:
                set_trace_cost(trace_cost)
            if not isinstance(result, list):  # pyright: ignore[reportUnnecessaryIsInstance] # runtime guard
                raise TypeError(f"Flow '{fname}' must return list[Document], got {type(result).__name__}")

            # Track flow I/O (best-effort)
            try:
                track_flow_io(fname, documents, result)
            except Exception:
                pass

            if user_summary:
                try:
                    service = get_tracking_service()
                    if service is not None:
                        span_id = get_current_span_id()
                        if span_id:
                            label = _resolve_label(user_summary, fn, {"project_name": project_name, "flow_options": flow_options})
                            output_hint = _build_output_hint(result)
                            service.schedule_summary(span_id, label, output_hint)
                except Exception:
                    pass

            # Document auto-save.
            # NOTE(review): a RunContext installed by this wrapper has already been
            # reset at this point; confirm _persist_documents does not depend on it.
            if ctx is not None:
                await _persist_documents(result, fname, ctx)

            return result

        traced = trace(
            level=trace_level,
            name=name or fname,
            ignore_input=trace_ignore_input,
            ignore_output=trace_ignore_output,
            ignore_inputs=trace_ignore_inputs,
            input_formatter=trace_input_formatter,
            output_formatter=trace_output_formatter,
            trim_documents=trace_trim_documents,
        )(_wrapper)

        flow_obj = cast(
            _FlowLike[Any],
            flow_decorator(
                name=name or fname,
                version=version,
                flow_run_name=flow_run_name or name or fname,
                retries=0 if retries is None else retries,
                retry_delay_seconds=retry_delay_seconds,
                task_runner=task_runner,
                description=description,
                timeout_seconds=timeout_seconds,
                validate_parameters=validate_parameters,
                persist_result=persist_result,
                result_storage=result_storage,
                result_serializer=result_serializer,
                cache_result_in_memory=cache_result_in_memory,
                log_prints=log_prints,
                on_completion=on_completion,
                on_failure=on_failure,
                on_cancellation=on_cancellation,
                on_crashed=on_crashed,
                on_running=on_running,
            )(traced),
        )
        flow_obj.input_document_types = resolved_input_types
        flow_obj.output_document_types = resolved_output_types
        flow_obj.estimated_minutes = estimated_minutes
        return flow_obj

    return _apply
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
# Public API of this module: the two pipeline decorators.
__all__ = [
    "pipeline_flow",
    "pipeline_task",
]
|