ai-pipeline-core 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ai_pipeline_core/__init__.py +70 -144
  2. ai_pipeline_core/deployment/__init__.py +6 -18
  3. ai_pipeline_core/deployment/base.py +392 -212
  4. ai_pipeline_core/deployment/contract.py +6 -10
  5. ai_pipeline_core/{utils → deployment}/deploy.py +50 -69
  6. ai_pipeline_core/deployment/helpers.py +16 -17
  7. ai_pipeline_core/{progress.py → deployment/progress.py} +23 -24
  8. ai_pipeline_core/{utils/remote_deployment.py → deployment/remote.py} +11 -14
  9. ai_pipeline_core/docs_generator/__init__.py +54 -0
  10. ai_pipeline_core/docs_generator/__main__.py +5 -0
  11. ai_pipeline_core/docs_generator/cli.py +196 -0
  12. ai_pipeline_core/docs_generator/extractor.py +324 -0
  13. ai_pipeline_core/docs_generator/guide_builder.py +644 -0
  14. ai_pipeline_core/docs_generator/trimmer.py +35 -0
  15. ai_pipeline_core/docs_generator/validator.py +114 -0
  16. ai_pipeline_core/document_store/__init__.py +13 -0
  17. ai_pipeline_core/document_store/_summary.py +9 -0
  18. ai_pipeline_core/document_store/_summary_worker.py +170 -0
  19. ai_pipeline_core/document_store/clickhouse.py +492 -0
  20. ai_pipeline_core/document_store/factory.py +38 -0
  21. ai_pipeline_core/document_store/local.py +312 -0
  22. ai_pipeline_core/document_store/memory.py +85 -0
  23. ai_pipeline_core/document_store/protocol.py +68 -0
  24. ai_pipeline_core/documents/__init__.py +12 -14
  25. ai_pipeline_core/documents/_context_vars.py +85 -0
  26. ai_pipeline_core/documents/_hashing.py +52 -0
  27. ai_pipeline_core/documents/attachment.py +85 -0
  28. ai_pipeline_core/documents/context.py +128 -0
  29. ai_pipeline_core/documents/document.py +318 -1434
  30. ai_pipeline_core/documents/mime_type.py +37 -82
  31. ai_pipeline_core/documents/utils.py +4 -12
  32. ai_pipeline_core/exceptions.py +10 -62
  33. ai_pipeline_core/images/__init__.py +32 -85
  34. ai_pipeline_core/images/_processing.py +5 -11
  35. ai_pipeline_core/llm/__init__.py +6 -4
  36. ai_pipeline_core/llm/ai_messages.py +106 -81
  37. ai_pipeline_core/llm/client.py +267 -158
  38. ai_pipeline_core/llm/model_options.py +12 -84
  39. ai_pipeline_core/llm/model_response.py +53 -99
  40. ai_pipeline_core/llm/model_types.py +8 -23
  41. ai_pipeline_core/logging/__init__.py +2 -7
  42. ai_pipeline_core/logging/logging.yml +1 -1
  43. ai_pipeline_core/logging/logging_config.py +27 -37
  44. ai_pipeline_core/logging/logging_mixin.py +15 -41
  45. ai_pipeline_core/observability/__init__.py +32 -0
  46. ai_pipeline_core/observability/_debug/__init__.py +30 -0
  47. ai_pipeline_core/observability/_debug/_auto_summary.py +94 -0
  48. ai_pipeline_core/{debug/config.py → observability/_debug/_config.py} +11 -7
  49. ai_pipeline_core/{debug/content.py → observability/_debug/_content.py} +134 -75
  50. ai_pipeline_core/{debug/processor.py → observability/_debug/_processor.py} +16 -17
  51. ai_pipeline_core/{debug/summary.py → observability/_debug/_summary.py} +113 -37
  52. ai_pipeline_core/observability/_debug/_types.py +75 -0
  53. ai_pipeline_core/{debug/writer.py → observability/_debug/_writer.py} +126 -196
  54. ai_pipeline_core/observability/_document_tracking.py +146 -0
  55. ai_pipeline_core/observability/_initialization.py +194 -0
  56. ai_pipeline_core/observability/_logging_bridge.py +57 -0
  57. ai_pipeline_core/observability/_summary.py +81 -0
  58. ai_pipeline_core/observability/_tracking/__init__.py +6 -0
  59. ai_pipeline_core/observability/_tracking/_client.py +178 -0
  60. ai_pipeline_core/observability/_tracking/_internal.py +28 -0
  61. ai_pipeline_core/observability/_tracking/_models.py +138 -0
  62. ai_pipeline_core/observability/_tracking/_processor.py +158 -0
  63. ai_pipeline_core/observability/_tracking/_service.py +311 -0
  64. ai_pipeline_core/observability/_tracking/_writer.py +229 -0
  65. ai_pipeline_core/{tracing.py → observability/tracing.py} +139 -335
  66. ai_pipeline_core/pipeline/__init__.py +10 -0
  67. ai_pipeline_core/pipeline/decorators.py +915 -0
  68. ai_pipeline_core/pipeline/options.py +16 -0
  69. ai_pipeline_core/prompt_manager.py +16 -102
  70. ai_pipeline_core/settings.py +26 -31
  71. ai_pipeline_core/testing.py +9 -0
  72. ai_pipeline_core-0.4.0.dist-info/METADATA +807 -0
  73. ai_pipeline_core-0.4.0.dist-info/RECORD +76 -0
  74. ai_pipeline_core/debug/__init__.py +0 -26
  75. ai_pipeline_core/documents/document_list.py +0 -420
  76. ai_pipeline_core/documents/flow_document.py +0 -112
  77. ai_pipeline_core/documents/task_document.py +0 -117
  78. ai_pipeline_core/documents/temporary_document.py +0 -74
  79. ai_pipeline_core/flow/__init__.py +0 -9
  80. ai_pipeline_core/flow/config.py +0 -494
  81. ai_pipeline_core/flow/options.py +0 -75
  82. ai_pipeline_core/pipeline.py +0 -718
  83. ai_pipeline_core/prefect.py +0 -63
  84. ai_pipeline_core/prompt_builder/__init__.py +0 -5
  85. ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +0 -23
  86. ai_pipeline_core/prompt_builder/global_cache.py +0 -78
  87. ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +0 -6
  88. ai_pipeline_core/prompt_builder/prompt_builder.py +0 -253
  89. ai_pipeline_core/prompt_builder/system_prompt.jinja2 +0 -41
  90. ai_pipeline_core/storage/__init__.py +0 -8
  91. ai_pipeline_core/storage/storage.py +0 -628
  92. ai_pipeline_core/utils/__init__.py +0 -8
  93. ai_pipeline_core-0.3.3.dist-info/METADATA +0 -569
  94. ai_pipeline_core-0.3.3.dist-info/RECORD +0 -57
  95. {ai_pipeline_core-0.3.3.dist-info → ai_pipeline_core-0.4.0.dist-info}/WHEEL +0 -0
  96. {ai_pipeline_core-0.3.3.dist-info → ai_pipeline_core-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,915 @@
1
+ """Pipeline decorators with Prefect integration, tracing, and document lifecycle.
2
+
3
+ Wrappers around Prefect's @task and @flow that add Laminar tracing,
4
+ enforce async-only execution, and auto-save documents to the DocumentStore.
5
+ """
6
+
7
+ import datetime
8
+ import inspect
9
+ import types
10
+ from collections.abc import Callable, Coroutine, Iterable
11
+ from functools import wraps
12
+ from typing import (
13
+ Any,
14
+ Protocol,
15
+ TypeVar,
16
+ Union,
17
+ cast,
18
+ get_args,
19
+ get_origin,
20
+ get_type_hints,
21
+ overload,
22
+ )
23
+
24
+ from lmnr import Laminar
25
+ from prefect.assets import Asset
26
+ from prefect.cache_policies import CachePolicy
27
+ from prefect.context import TaskRunContext
28
+ from prefect.flows import FlowStateHook
29
+ from prefect.flows import flow as _prefect_flow
30
+ from prefect.futures import PrefectFuture
31
+ from prefect.results import ResultSerializer, ResultStorage
32
+ from prefect.task_runners import TaskRunner
33
+ from prefect.tasks import task as _prefect_task
34
+ from prefect.utilities.annotations import NotSet
35
+ from pydantic import BaseModel
36
+
37
+ from ai_pipeline_core.document_store import get_document_store
38
+ from ai_pipeline_core.documents import Document
39
+ from ai_pipeline_core.documents.context import (
40
+ RunContext,
41
+ TaskDocumentContext,
42
+ get_run_context,
43
+ reset_run_context,
44
+ reset_task_context,
45
+ set_run_context,
46
+ set_task_context,
47
+ )
48
+ from ai_pipeline_core.documents.utils import is_document_sha256
49
+ from ai_pipeline_core.logging import get_pipeline_logger
50
+ from ai_pipeline_core.observability._document_tracking import get_current_span_id, track_flow_io, track_task_io
51
+ from ai_pipeline_core.observability._initialization import get_tracking_service
52
+ from ai_pipeline_core.observability._tracking._models import DocumentEventType
53
+ from ai_pipeline_core.observability.tracing import TraceLevel, set_trace_cost, trace
54
+ from ai_pipeline_core.pipeline.options import FlowOptions
55
+
56
+ logger = get_pipeline_logger(__name__)
57
+
58
# --------------------------------------------------------------------------- #
# Public callback aliases (Prefect stubs omit these exact types)
# --------------------------------------------------------------------------- #
type RetryConditionCallable = Callable[[Any, Any, Any], bool]  # hook(obj, run, state) -> retry?
type StateHookCallable = Callable[[Any, Any, Any], None]  # hook(obj, run, state) side effect
type TaskRunNameValueOrCallable = str | Callable[[], str]  # static name or lazy factory

# --------------------------------------------------------------------------- #
# Typing helpers
# --------------------------------------------------------------------------- #
R_co = TypeVar("R_co", covariant=True)  # task result type (covariant: producer position)
FO_contra = TypeVar("FO_contra", bound=FlowOptions, contravariant=True)  # flow options (consumer position)
70
+
71
+
72
class _TaskLike(Protocol[R_co]):
    """Protocol for type-safe Prefect task representation.

    Describes the surface of the object returned by @pipeline_task: an
    awaitable call, Prefect's ``submit``/``map``, the task ``name``, and the
    ``estimated_minutes`` attribute attached by the decorator. ``__getattr__``
    keeps every other Prefect task attribute accessible without enumerating it.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, R_co]: ...

    # Prefect task methods, passed through untyped.
    submit: Callable[..., Any]
    map: Callable[..., Any]
    name: str | None
    # Duration estimate attached by @pipeline_task for progress tracking.
    estimated_minutes: int

    def __getattr__(self, name: str) -> Any: ...
83
+
84
+
85
class _FlowLike(Protocol[FO_contra]):
    """Protocol for decorated flow objects returned by @pipeline_flow.

    Pipeline flows have a fixed call signature:
    ``(project_name, documents, flow_options) -> list[Document]``.
    Document-type metadata and the duration estimate are attached by the
    decorator; ``__getattr__`` forwards remaining Prefect flow attributes.
    """

    def __call__(
        self,
        project_name: str,
        documents: list[Document],
        flow_options: FO_contra,
    ) -> Coroutine[Any, Any, list[Document]]: ...

    name: str | None
    # Document types accepted/produced, derived from the flow's annotations.
    input_document_types: list[type[Document]]
    output_document_types: list[type[Document]]
    # Duration estimate attached by @pipeline_flow for progress tracking.
    estimated_minutes: int

    def __getattr__(self, name: str) -> Any: ...
101
+
102
+
103
+ # --------------------------------------------------------------------------- #
104
+ # Annotation parsing helpers
105
+ # --------------------------------------------------------------------------- #
106
+ def _flatten_union(tp: Any) -> list[Any]:
107
+ """Flatten Union / X | Y annotations into a list of constituent types."""
108
+ origin = get_origin(tp)
109
+ if origin is Union or isinstance(tp, types.UnionType):
110
+ result: list[Any] = []
111
+ for arg in get_args(tp):
112
+ result.extend(_flatten_union(arg))
113
+ return result
114
+ return [tp]
115
+
116
+
117
+ def _find_non_document_leaves(tp: Any) -> list[Any]:
118
+ """Walk a return type annotation and collect leaf types that are not Document subclasses or NoneType.
119
+
120
+ Returns empty list when all leaf types are valid (Document subclasses or None).
121
+ Used by @pipeline_task to validate return annotations at decoration time.
122
+ """
123
+ if tp is type(None) or (isinstance(tp, type) and issubclass(tp, Document)):
124
+ return []
125
+
126
+ origin = get_origin(tp)
127
+
128
+ # Union / X | Y: all branches must be valid
129
+ if origin is Union or isinstance(tp, types.UnionType):
130
+ return [leaf for arg in get_args(tp) for leaf in _find_non_document_leaves(arg)]
131
+
132
+ # list[X]: recurse into element type
133
+ if origin is list:
134
+ args = get_args(tp)
135
+ return _find_non_document_leaves(args[0]) if args else [tp]
136
+
137
+ # tuple[X, Y] or tuple[X, ...]
138
+ if origin is tuple:
139
+ args = get_args(tp)
140
+ if not args:
141
+ return [tp]
142
+ elements = (args[0],) if (len(args) == 2 and args[1] is Ellipsis) else args
143
+ return [leaf for arg in elements for leaf in _find_non_document_leaves(arg)]
144
+
145
+ # Everything else is invalid (int, str, Any, object, dict, etc.)
146
+ return [tp]
147
+
148
+
149
def _parse_document_types_from_annotation(annotation: Any) -> list[type[Document]]:
    """Extract Document subclasses from a list[...] type annotation.

    Handles list[DocA], list[DocA | DocB], list[Union[DocA, DocB]].
    Returns an empty list if the annotation is not a list of Document
    subclasses.
    """
    if get_origin(annotation) is not list:
        return []

    type_args = get_args(annotation)
    if not type_args:
        return []

    # The element type may itself be a union; expand it to candidates.
    candidates = _flatten_union(type_args[0])

    found: list[type[Document]] = []
    for candidate in candidates:
        if isinstance(candidate, type) and issubclass(candidate, Document):
            found.append(candidate)
    return found
167
+
168
+
169
+ def _resolve_type_hints(fn: Callable[..., Any]) -> dict[str, Any]:
170
+ """Safely resolve type hints, falling back to empty dict on failure."""
171
+ try:
172
+ return get_type_hints(fn, include_extras=True)
173
+ except Exception:
174
+ logger.warning(
175
+ "Failed to resolve type hints for '%s'. Ensure all annotations are valid and importable.",
176
+ _callable_name(fn, "unknown"),
177
+ )
178
+ return {}
179
+
180
+
181
+ # --------------------------------------------------------------------------- #
182
+ # Document extraction helper
183
+ # --------------------------------------------------------------------------- #
184
def _extract_documents(result: Any) -> list[Document]:
    """Recursively extract unique Document instances from a result value.

    Walks tuples, lists, dicts, and Pydantic BaseModel fields. Deduplicates
    by object identity (the same instance appearing multiple times is
    collected only once). Documents are matched before BaseModel because
    Document itself subclasses BaseModel.
    """
    collected: list[Document] = []
    visited: set[int] = set()

    def _visit(node: Any) -> None:
        # Identity-based memo: prevents revisiting shared/cyclic structures.
        key = id(node)
        if key in visited:
            return
        visited.add(key)

        if isinstance(node, Document):
            collected.append(node)
        elif isinstance(node, (list, tuple)):
            for element in cast(Iterable[Any], node):
                _visit(element)
        elif isinstance(node, dict):
            for element in cast(Iterable[Any], node.values()):
                _visit(element)
        elif isinstance(node, BaseModel):
            for field in type(node).model_fields:
                _visit(getattr(node, field))

    _visit(result)
    return collected
219
+
220
+
221
+ # --------------------------------------------------------------------------- #
222
+ # Small helpers
223
+ # --------------------------------------------------------------------------- #
224
+ def _callable_name(obj: Any, fallback: str) -> str:
225
+ """Safely extract callable's name for error messages."""
226
+ try:
227
+ n = getattr(obj, "__name__", None)
228
+ return n if isinstance(n, str) else fallback
229
+ except Exception:
230
+ return fallback
231
+
232
+
233
+ def _is_already_traced(func: Callable[..., Any]) -> bool:
234
+ """Check if a function has already been wrapped by the trace decorator."""
235
+ if hasattr(func, "__is_traced__") and func.__is_traced__: # type: ignore[attr-defined]
236
+ return True
237
+
238
+ current = func
239
+ depth = 0
240
+ while hasattr(current, "__wrapped__") and depth < 10:
241
+ wrapped = current.__wrapped__ # type: ignore[attr-defined]
242
+ if hasattr(wrapped, "__is_traced__") and wrapped.__is_traced__:
243
+ return True
244
+ current = wrapped
245
+ depth += 1
246
+ return False
247
+
248
+
249
+ # --------------------------------------------------------------------------- #
250
+ # Tracking helpers
251
+ # --------------------------------------------------------------------------- #
252
+ def _resolve_label(user_summary: str | bool, fn: Callable[..., Any], kwargs: dict[str, Any]) -> str:
253
+ """Resolve user_summary to a label string."""
254
+ if isinstance(user_summary, str):
255
+ try:
256
+ return user_summary.format(**kwargs)
257
+ except (KeyError, IndexError):
258
+ return user_summary
259
+ return _callable_name(fn, "task").replace("_", " ").title()
260
+
261
+
262
+ def _build_output_hint(result: object) -> str:
263
+ """Build a privacy-safe metadata string describing a task's output."""
264
+ if result is None:
265
+ return "None"
266
+ if isinstance(result, list) and result and isinstance(result[0], Document):
267
+ doc_list = cast(list[Document], result)
268
+ class_counts: dict[str, int] = {}
269
+ total_size = 0
270
+ for doc in doc_list:
271
+ cls_name = type(doc).__name__
272
+ class_counts[cls_name] = class_counts.get(cls_name, 0) + 1
273
+ total_size += len(doc.content)
274
+ parts = [f"{name} x{count}" for name, count in class_counts.items()]
275
+ return f"{len(doc_list)} documents ({', '.join(parts)}), total {total_size / 1024:.0f}KB"
276
+ if isinstance(result, Document):
277
+ size = len(result.content)
278
+ return f"{type(result).__name__}(name={result.name!r}, {size / 1024:.0f}KB, {result.mime_type})"
279
+ if isinstance(result, (list, tuple)):
280
+ seq = cast(list[Any] | tuple[Any, ...], result)
281
+ return f"{type(seq).__name__} with {len(seq)} items"
282
+ if isinstance(result, str):
283
+ return f"str ({len(result)} chars)"
284
+ return type(result).__name__
285
+
286
+
287
+ # --------------------------------------------------------------------------- #
288
+ # Store event emission helper
289
+ # --------------------------------------------------------------------------- #
290
def _emit_store_events(documents: list[Document], event_type: DocumentEventType) -> None:
    """Emit store lifecycle events for documents. No-op if tracking is not available.

    Any failure is swallowed deliberately: observability must never break
    the pipeline itself.
    """
    try:
        tracker = get_tracking_service()
        if tracker is None:
            return
        current_span = get_current_span_id()
        for document in documents:
            tracker.track_document_event(
                document_sha256=document.sha256,
                span_id=current_span,
                event_type=event_type,
            )
    except Exception:
        pass
305
+
306
+
307
+ # --------------------------------------------------------------------------- #
308
+ # Document persistence helper (used by @pipeline_task and @pipeline_flow)
309
+ # --------------------------------------------------------------------------- #
310
async def _persist_documents(
    documents: list[Document],
    label: str,
    ctx: "TaskDocumentContext",
    *,
    check_created: bool = False,
) -> None:
    """Validate provenance, deduplicate, and save documents to the store.

    Silently skips if no store or no run context is configured.
    Logs warnings on persistence failure (graceful degradation).

    Args:
        documents: Documents extracted from a task/flow result.
        label: Human-readable task/flow name used as a log prefix.
        ctx: Per-task document context used for provenance validation.
        check_created: When True, warn if a returned document was not created in this
            context. Enabled for @pipeline_task, disabled for @pipeline_flow (flows
            delegate creation to nested tasks).
    """
    run_ctx = get_run_context()
    store = get_document_store()
    # Persistence is best-effort: without both a run scope and a store there
    # is nowhere to save to, so silently do nothing.
    if run_ctx is None or store is None:
        return

    if not documents:
        return

    deduped: list[Document] = []
    try:
        # Collect all SHA256 references (sources + origins) for existence check
        ref_sha256s: set[str] = set()
        for doc in documents:
            for src in doc.sources:
                if is_document_sha256(src):
                    ref_sha256s.add(src)
            for origin in doc.origins:
                ref_sha256s.add(origin)

        existing: set[str] = set()
        if ref_sha256s:
            # sorted() gives the store a deterministic query order.
            existing = await store.check_existing(sorted(ref_sha256s))

        provenance_warnings = ctx.validate_provenance(documents, existing, check_created=check_created)
        for warning in provenance_warnings:
            logger.warning("[%s] %s", label, warning)

        # Deduplicate and save
        deduped = TaskDocumentContext.deduplicate(documents)
        await store.save_batch(deduped, run_ctx.run_scope)

        _emit_store_events(deduped, DocumentEventType.STORE_SAVED)

        # Finalize: warn about created-but-not-returned documents
        finalize_warnings = ctx.finalize(documents)
        for warning in finalize_warnings:
            logger.warning("[%s] %s", label, warning)
    except Exception:
        # `deduped` is still empty if the failure happened before
        # deduplication; fall back to the raw list for the failure event.
        _emit_store_events(deduped or documents, DocumentEventType.STORE_SAVE_FAILED)
        logger.warning("Failed to persist documents from '%s'", label, exc_info=True)
367
+
368
+
369
+ # --------------------------------------------------------------------------- #
370
+ # @pipeline_task — async-only, traced, auto-persists documents
371
+ # --------------------------------------------------------------------------- #
372
# Bare usage: @pipeline_task applied directly to an async function.
@overload
def pipeline_task(__fn: Callable[..., Coroutine[Any, Any, R_co]], /) -> _TaskLike[R_co]: ...  # noqa: UP047
374
# Parenthesized usage: @pipeline_task(...) with keyword configuration,
# returning a decorator to apply to the async function.
@overload
def pipeline_task(
    *,
    # tracing
    trace_level: TraceLevel = "always",
    trace_ignore_input: bool = False,
    trace_ignore_output: bool = False,
    trace_ignore_inputs: list[str] | None = None,
    trace_input_formatter: Callable[..., str] | None = None,
    trace_output_formatter: Callable[..., str] | None = None,
    trace_cost: float | None = None,
    expected_cost: float | None = None,
    trace_trim_documents: bool = True,
    # tracking
    user_summary: bool | str = False,
    # document lifecycle
    estimated_minutes: int = 1,
    persist: bool = True,
    # prefect passthrough
    name: str | None = None,
    description: str | None = None,
    tags: Iterable[str] | None = None,
    version: str | None = None,
    cache_policy: CachePolicy | type[NotSet] = NotSet,
    cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
    cache_expiration: datetime.timedelta | None = None,
    task_run_name: TaskRunNameValueOrCallable | None = None,
    retries: int | None = None,
    retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
    retry_jitter_factor: float | None = None,
    persist_result: bool | None = None,
    result_storage: ResultStorage | str | None = None,
    result_serializer: ResultSerializer | str | None = None,
    result_storage_key: str | None = None,
    cache_result_in_memory: bool = True,
    timeout_seconds: int | float | None = None,
    log_prints: bool | None = False,
    refresh_cache: bool | None = None,
    on_completion: list[StateHookCallable] | None = None,
    on_failure: list[StateHookCallable] | None = None,
    retry_condition_fn: RetryConditionCallable | None = None,
    viz_return_value: bool | None = None,
    asset_deps: list[str | Asset] | None = None,
) -> Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]: ...
418
+
419
+
420
def pipeline_task(  # noqa: UP047
    __fn: Callable[..., Coroutine[Any, Any, R_co]] | None = None,
    /,
    *,
    # tracing
    trace_level: TraceLevel = "always",
    trace_ignore_input: bool = False,
    trace_ignore_output: bool = False,
    trace_ignore_inputs: list[str] | None = None,
    trace_input_formatter: Callable[..., str] | None = None,
    trace_output_formatter: Callable[..., str] | None = None,
    trace_cost: float | None = None,
    expected_cost: float | None = None,
    trace_trim_documents: bool = True,
    # tracking
    user_summary: bool | str = False,
    # document lifecycle
    estimated_minutes: int = 1,
    persist: bool = True,
    # prefect passthrough
    name: str | None = None,
    description: str | None = None,
    tags: Iterable[str] | None = None,
    version: str | None = None,
    cache_policy: CachePolicy | type[NotSet] = NotSet,
    cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
    cache_expiration: datetime.timedelta | None = None,
    task_run_name: TaskRunNameValueOrCallable | None = None,
    retries: int | None = None,
    retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
    retry_jitter_factor: float | None = None,
    persist_result: bool | None = None,
    result_storage: ResultStorage | str | None = None,
    result_serializer: ResultSerializer | str | None = None,
    result_storage_key: str | None = None,
    cache_result_in_memory: bool = True,
    timeout_seconds: int | float | None = None,
    log_prints: bool | None = False,
    refresh_cache: bool | None = None,
    on_completion: list[StateHookCallable] | None = None,
    on_failure: list[StateHookCallable] | None = None,
    retry_condition_fn: RetryConditionCallable | None = None,
    viz_return_value: bool | None = None,
    asset_deps: list[str | Asset] | None = None,
) -> _TaskLike[R_co] | Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]:
    """Decorate an async function as a traced Prefect task with document auto-save.

    After the wrapped function returns, if documents are found in the result
    and a DocumentStore + RunContext are available, documents are validated
    for provenance, deduplicated by SHA256, and saved to the store.

    When persist=True (default), the return type annotation is validated at
    decoration time. Allowed return types::

        -> MyDocument                      # single Document
        -> list[DocA] / list[DocA | DocB]  # list of Documents
        -> tuple[DocA, DocB]               # tuple of Documents
        -> tuple[list[DocA], list[DocB]]   # tuple of lists
        -> tuple[DocA, ...]                # variable-length tuple
        -> None                            # side-effect tasks
        -> DocA | None                     # optional Document

    Use persist=False for tasks returning non-document values (tracing and
    retries still apply, but no return type validation or document auto-save).

    Args:
        __fn: Function to decorate (when used without parentheses).
        trace_level: When to trace ("always", "debug", "off").
        trace_ignore_input: Don't trace input arguments.
        trace_ignore_output: Don't trace return value.
        trace_ignore_inputs: List of parameter names to exclude from tracing.
        trace_input_formatter: Custom formatter for input tracing.
        trace_output_formatter: Custom formatter for output tracing.
        trace_cost: Optional cost value to track in metadata.
        expected_cost: Optional expected cost budget for this task.
        trace_trim_documents: Trim document content in traces (default True).
        user_summary: Enable LLM-generated span summaries.
        estimated_minutes: Estimated duration for progress tracking (must be > 0).
        persist: Auto-save returned documents to the store (default True).
        name: Task name (defaults to function name).
        description: Human-readable task description.
        tags: Tags for organization and filtering.
        version: Task version string.
        cache_policy: Caching policy for task results.
        cache_key_fn: Custom cache key generation.
        cache_expiration: How long to cache results.
        task_run_name: Dynamic or static run name.
        retries: Number of retry attempts (default 0).
        retry_delay_seconds: Delay between retries.
        retry_jitter_factor: Random jitter for retry delays.
        persist_result: Whether to persist results.
        result_storage: Where to store results.
        result_serializer: How to serialize results.
        result_storage_key: Custom storage key.
        cache_result_in_memory: Keep results in memory.
        timeout_seconds: Task execution timeout.
        log_prints: Capture print() statements.
        refresh_cache: Force cache refresh.
        on_completion: Hooks for successful completion.
        on_failure: Hooks for task failure.
        retry_condition_fn: Custom retry condition.
        viz_return_value: Include return value in visualization.
        asset_deps: Upstream asset dependencies.

    Raises:
        ValueError: If estimated_minutes is less than 1.
        TypeError: At decoration time, if the target is not ``async def``, is
            already @trace-decorated, references the removed ``DocumentList``
            type, or (with persist=True) has a missing/invalid return annotation.
    """
    if estimated_minutes < 1:
        raise ValueError(f"estimated_minutes must be >= 1, got {estimated_minutes}")

    task_decorator: Callable[..., Any] = _prefect_task

    def _apply(fn: Callable[..., Coroutine[Any, Any, R_co]]) -> _TaskLike[R_co]:
        # Decoration-time validation: async-only, not double-traced, no
        # stale annotations, and (if persisting) a Document-typed return.
        fname = _callable_name(fn, "task")

        if not inspect.iscoroutinefunction(fn):
            raise TypeError(f"@pipeline_task target '{fname}' must be 'async def'")

        if _is_already_traced(fn):
            raise TypeError(
                f"@pipeline_task target '{fname}' is already decorated "
                f"with @trace. Remove the @trace decorator - @pipeline_task includes "
                f"tracing automatically."
            )

        # Reject stale DocumentList references in annotations
        for ann_name, ann_value in getattr(fn, "__annotations__", {}).items():
            if "DocumentList" in str(ann_value):
                label = "return type" if ann_name == "return" else f"parameter '{ann_name}'"
                raise TypeError(f"@pipeline_task '{fname}' {label} references 'DocumentList' which has been removed. Use 'list[Document]' instead.")

        # Validate return type annotation when persist=True
        if persist:
            hints = _resolve_type_hints(fn)
            if "return" not in hints:
                raise TypeError(
                    f"@pipeline_task '{fname}': missing return type annotation. "
                    f"Persisted tasks must return Document types "
                    f"(Document, list[Document], tuple[Document, ...], or None). "
                    f"Add a return annotation or use persist=False."
                )
            bad_types = _find_non_document_leaves(hints["return"])
            if bad_types:
                bad_names = ", ".join(getattr(t, "__name__", str(t)) for t in bad_types)
                raise TypeError(
                    f"@pipeline_task '{fname}': return type contains non-Document types: {bad_names}. "
                    f"Persisted tasks must return Document, list[Document], "
                    f"tuple[Document, ...], or None. "
                    f"Use persist=False for tasks returning non-document values."
                )

        @wraps(fn)
        async def _wrapper(*args: Any, **kwargs: Any) -> R_co:
            # Attach descriptive span attributes; tracing failures are
            # non-fatal by design.
            attrs: dict[str, Any] = {}
            if description:
                attrs["description"] = description
            if expected_cost is not None:
                attrs["expected_cost"] = expected_cost
            if attrs:
                try:
                    Laminar.set_span_attributes(attrs)  # pyright: ignore[reportArgumentType]
                except Exception:
                    pass

            # Set up TaskDocumentContext BEFORE calling fn() so Document.__init__ can register
            ctx: TaskDocumentContext | None = None
            task_token = None
            if persist and get_run_context() is not None and get_document_store() is not None:
                ctx = TaskDocumentContext()
                task_token = set_task_context(ctx)

            try:
                result = await fn(*args, **kwargs)
            finally:
                # Always restore the previous context, even when fn raises.
                if task_token is not None:
                    reset_task_context(task_token)

            if trace_cost is not None and trace_cost > 0:
                set_trace_cost(trace_cost)

            # Track task I/O and schedule summaries
            try:
                track_task_io(fname, args, kwargs, result)
            except Exception:
                pass

            if user_summary:
                try:
                    service = get_tracking_service()
                    if service is not None:
                        span_id = get_current_span_id()
                        if span_id:
                            label = _resolve_label(user_summary, fn, kwargs)
                            output_hint = _build_output_hint(result)
                            service.schedule_summary(span_id, label, output_hint)
                except Exception:
                    pass

            # Document auto-save
            if persist and ctx is not None:
                await _persist_documents(_extract_documents(result), fname, ctx, check_created=True)

            return result

        # Apply tracing first, then wrap the traced coroutine in a Prefect task.
        traced_fn = trace(
            level=trace_level,
            name=name or fname,
            ignore_input=trace_ignore_input,
            ignore_output=trace_ignore_output,
            ignore_inputs=trace_ignore_inputs,
            input_formatter=trace_input_formatter,
            output_formatter=trace_output_formatter,
            trim_documents=trace_trim_documents,
        )(_wrapper)

        task_obj = cast(
            _TaskLike[R_co],
            task_decorator(
                name=name or fname,
                description=description,
                tags=tags,
                version=version,
                cache_policy=cache_policy,
                cache_key_fn=cache_key_fn,
                cache_expiration=cache_expiration,
                task_run_name=task_run_name or name or fname,
                retries=0 if retries is None else retries,
                retry_delay_seconds=retry_delay_seconds,
                retry_jitter_factor=retry_jitter_factor,
                persist_result=persist_result,
                result_storage=result_storage,
                result_serializer=result_serializer,
                result_storage_key=result_storage_key,
                cache_result_in_memory=cache_result_in_memory,
                timeout_seconds=timeout_seconds,
                log_prints=log_prints,
                refresh_cache=refresh_cache,
                on_completion=on_completion,
                on_failure=on_failure,
                retry_condition_fn=retry_condition_fn,
                viz_return_value=viz_return_value,
                asset_deps=asset_deps,
            )(traced_fn),
        )
        # Expose the duration estimate for progress tracking on the task object.
        task_obj.estimated_minutes = estimated_minutes
        return task_obj

    # Bare @pipeline_task usage passes fn positionally; parenthesized usage
    # returns _apply as the decorator.
    return _apply(__fn) if __fn else _apply
665
+
666
+
667
+ # --------------------------------------------------------------------------- #
668
+ # @pipeline_flow — async-only, traced, annotation-driven document types
669
+ # --------------------------------------------------------------------------- #
670
+ def pipeline_flow(
671
+ *,
672
+ # tracing
673
+ trace_level: TraceLevel = "always",
674
+ trace_ignore_input: bool = False,
675
+ trace_ignore_output: bool = False,
676
+ trace_ignore_inputs: list[str] | None = None,
677
+ trace_input_formatter: Callable[..., str] | None = None,
678
+ trace_output_formatter: Callable[..., str] | None = None,
679
+ trace_cost: float | None = None,
680
+ expected_cost: float | None = None,
681
+ trace_trim_documents: bool = True,
682
+ # tracking
683
+ user_summary: bool | str = False,
684
+ # document type specification
685
+ estimated_minutes: int = 1,
686
+ # prefect passthrough
687
+ name: str | None = None,
688
+ version: str | None = None,
689
+ flow_run_name: Callable[[], str] | str | None = None,
690
+ retries: int | None = None,
691
+ retry_delay_seconds: int | float | None = None,
692
+ task_runner: TaskRunner[PrefectFuture[Any]] | None = None,
693
+ description: str | None = None,
694
+ timeout_seconds: int | float | None = None,
695
+ validate_parameters: bool = True,
696
+ persist_result: bool | None = None,
697
+ result_storage: ResultStorage | str | None = None,
698
+ result_serializer: ResultSerializer | str | None = None,
699
+ cache_result_in_memory: bool = True,
700
+ log_prints: bool | None = None,
701
+ on_completion: list[FlowStateHook[Any, Any]] | None = None,
702
+ on_failure: list[FlowStateHook[Any, Any]] | None = None,
703
+ on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
704
+ on_crashed: list[FlowStateHook[Any, Any]] | None = None,
705
+ on_running: list[FlowStateHook[Any, Any]] | None = None,
706
+ ) -> Callable[[Callable[..., Coroutine[Any, Any, list[Document]]]], _FlowLike[Any]]:
707
+ """Decorate an async function as a traced Prefect flow with annotation-driven document types.
708
+
709
+ Extracts input/output document types from the function's type annotations
710
+ at decoration time and attaches them as ``input_document_types`` and
711
+ ``output_document_types`` attributes on the returned flow object.
712
+
713
+ Required function signature::
714
+
715
+ @pipeline_flow(estimated_minutes=30)
716
+ async def my_flow(
717
+ project_name: str,
718
+ documents: list[DocA | DocB],
719
+ flow_options: FlowOptions,
720
+ ) -> list[OutputDoc]:
721
+ ...
722
+
723
+ Args:
724
+ user_summary: Enable LLM-generated span summaries.
725
+ estimated_minutes: Estimated duration for progress tracking (must be >= 1).
726
+
727
+ Returns:
728
+ Decorator that produces a _FlowLike object with ``input_document_types``,
729
+ ``output_document_types``, and ``estimated_minutes`` attributes.
730
+
731
+ Raises:
732
+ TypeError: If the function is not async, has wrong parameter count/types,
733
+ missing return annotation, or output types overlap input types.
734
+ ValueError: If estimated_minutes < 1.
735
+ """
736
+ if estimated_minutes < 1:
737
+ raise ValueError(f"estimated_minutes must be >= 1, got {estimated_minutes}")
738
+
739
+ flow_decorator: Callable[..., Any] = _prefect_flow
740
+
741
+ def _apply(fn: Callable[..., Coroutine[Any, Any, list[Document]]]) -> _FlowLike[Any]:
742
+ fname = _callable_name(fn, "flow")
743
+
744
+ if not inspect.iscoroutinefunction(fn):
745
+ raise TypeError(f"@pipeline_flow '{fname}' must be declared with 'async def'")
746
+
747
+ if _is_already_traced(fn):
748
+ raise TypeError(
749
+ f"@pipeline_flow target '{fname}' is already decorated "
750
+ f"with @trace. Remove the @trace decorator - @pipeline_flow includes "
751
+ f"tracing automatically."
752
+ )
753
+
754
+ sig = inspect.signature(fn)
755
+ params = list(sig.parameters.values())
756
+ if len(params) != 3:
757
+ raise TypeError(
758
+ f"@pipeline_flow '{fname}' must have exactly 3 parameters "
759
+ f"(project_name: str, documents: list[...], flow_options: FlowOptions), got {len(params)}"
760
+ )
761
+
762
+ # Resolve document types from annotations
763
+ hints = _resolve_type_hints(fn)
764
+
765
+ # Validate first parameter is str
766
+ if params[0].name in hints and hints[params[0].name] is not str:
767
+ raise TypeError(f"@pipeline_flow '{fname}': first parameter '{params[0].name}' must be annotated as 'str'")
768
+
769
+ # Validate third parameter is FlowOptions or subclass
770
+ if params[2].name in hints:
771
+ p2_type = hints[params[2].name]
772
+ if not (isinstance(p2_type, type) and issubclass(p2_type, FlowOptions)):
773
+ raise TypeError(f"@pipeline_flow '{fname}': third parameter '{params[2].name}' must be FlowOptions or subclass, got {p2_type}")
774
+
775
+ # Extract input types from documents parameter annotation
776
+ resolved_input_types: list[type[Document]]
777
+ if params[1].name in hints:
778
+ resolved_input_types = _parse_document_types_from_annotation(hints[params[1].name])
779
+ else:
780
+ resolved_input_types = []
781
+
782
+ # Extract output types from return annotation
783
+ resolved_output_types: list[type[Document]]
784
+ if "return" in hints:
785
+ resolved_output_types = _parse_document_types_from_annotation(hints["return"])
786
+ else:
787
+ resolved_output_types = []
788
+
789
+ # Validate return annotation contains Document subclasses
790
+ if "return" in hints and not resolved_output_types:
791
+ raise TypeError(
792
+ f"@pipeline_flow '{fname}': return annotation does not contain "
793
+ f"Document subclasses. Flows must return list[SomeDocument]. "
794
+ f"Got: {hints['return']}."
795
+ )
796
+
797
+ # Output types must not overlap input types (skip for base Document used in generic flows)
798
+ if resolved_output_types and resolved_input_types:
799
+ overlap = set(resolved_output_types) & set(resolved_input_types) - {Document}
800
+ if overlap:
801
+ names = ", ".join(t.__name__ for t in overlap)
802
+ raise TypeError(f"@pipeline_flow '{fname}': output types [{names}] cannot also be input types")
803
+
804
+ @wraps(fn)
805
+ async def _wrapper(
806
+ project_name: str,
807
+ documents: list[Document],
808
+ flow_options: Any,
809
+ ) -> list[Document]:
810
+ attrs: dict[str, Any] = {}
811
+ if description:
812
+ attrs["description"] = description
813
+ if expected_cost is not None:
814
+ attrs["expected_cost"] = expected_cost
815
+ if attrs:
816
+ try:
817
+ Laminar.set_span_attributes(attrs) # pyright: ignore[reportArgumentType]
818
+ except Exception:
819
+ pass
820
+
821
+ # Set RunContext for nested tasks (only if not already set by deployment)
822
+ existing_ctx = get_run_context()
823
+ run_token = None
824
+ if existing_ctx is None:
825
+ run_scope = f"{project_name}/{name or fname}"
826
+ run_token = set_run_context(RunContext(run_scope=run_scope))
827
+
828
+ # Set up TaskDocumentContext for flow-level document lifecycle
829
+ ctx: TaskDocumentContext | None = None
830
+ task_token = None
831
+ if get_run_context() is not None and get_document_store() is not None:
832
+ ctx = TaskDocumentContext()
833
+ task_token = set_task_context(ctx)
834
+
835
+ try:
836
+ result = await fn(project_name, documents, flow_options)
837
+ finally:
838
+ if task_token is not None:
839
+ reset_task_context(task_token)
840
+ if run_token is not None:
841
+ reset_run_context(run_token)
842
+
843
+ if trace_cost is not None and trace_cost > 0:
844
+ set_trace_cost(trace_cost)
845
+ if not isinstance(result, list): # pyright: ignore[reportUnnecessaryIsInstance] # runtime guard
846
+ raise TypeError(f"Flow '{fname}' must return list[Document], got {type(result).__name__}")
847
+
848
+ # Track flow I/O
849
+ try:
850
+ track_flow_io(fname, documents, result)
851
+ except Exception:
852
+ pass
853
+
854
+ if user_summary:
855
+ try:
856
+ service = get_tracking_service()
857
+ if service is not None:
858
+ span_id = get_current_span_id()
859
+ if span_id:
860
+ label = _resolve_label(user_summary, fn, {"project_name": project_name, "flow_options": flow_options})
861
+ output_hint = _build_output_hint(result)
862
+ service.schedule_summary(span_id, label, output_hint)
863
+ except Exception:
864
+ pass
865
+
866
+ # Document auto-save
867
+ if ctx is not None:
868
+ await _persist_documents(result, fname, ctx)
869
+
870
+ return result
871
+
872
+ traced = trace(
873
+ level=trace_level,
874
+ name=name or fname,
875
+ ignore_input=trace_ignore_input,
876
+ ignore_output=trace_ignore_output,
877
+ ignore_inputs=trace_ignore_inputs,
878
+ input_formatter=trace_input_formatter,
879
+ output_formatter=trace_output_formatter,
880
+ trim_documents=trace_trim_documents,
881
+ )(_wrapper)
882
+
883
+ flow_obj = cast(
884
+ _FlowLike[Any],
885
+ flow_decorator(
886
+ name=name or fname,
887
+ version=version,
888
+ flow_run_name=flow_run_name or name or fname,
889
+ retries=0 if retries is None else retries,
890
+ retry_delay_seconds=retry_delay_seconds,
891
+ task_runner=task_runner,
892
+ description=description,
893
+ timeout_seconds=timeout_seconds,
894
+ validate_parameters=validate_parameters,
895
+ persist_result=persist_result,
896
+ result_storage=result_storage,
897
+ result_serializer=result_serializer,
898
+ cache_result_in_memory=cache_result_in_memory,
899
+ log_prints=log_prints,
900
+ on_completion=on_completion,
901
+ on_failure=on_failure,
902
+ on_cancellation=on_cancellation,
903
+ on_crashed=on_crashed,
904
+ on_running=on_running,
905
+ )(traced),
906
+ )
907
+ flow_obj.input_document_types = resolved_input_types
908
+ flow_obj.output_document_types = resolved_output_types
909
+ flow_obj.estimated_minutes = estimated_minutes
910
+ return flow_obj
911
+
912
+ return _apply
913
+
914
+
915
+ __all__ = ["pipeline_flow", "pipeline_task"]