ai-pipeline-core 0.1.7__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -1,397 +1,395 @@
1
- """Pipeline decorators that combine Prefect functionality with tracing support.
1
+ """
2
+ ai_pipeline_core.pipeline
3
+ =========================
4
+
5
+ Tiny wrappers around Prefect's public ``@task`` and ``@flow`` that add our
6
+ ``trace`` decorator and **require async functions**.
7
+
8
+ Why this exists
9
+ ---------------
10
+ Prefect tasks/flows are awaitable at runtime, but their public type stubs
11
+ don’t declare that clearly. We therefore:
12
+
13
+ 1) Return the **real Prefect objects** (so you keep every Prefect method).
14
+ 2) Type them as small Protocols that say “this is awaitable and has common
15
+ helpers like `.submit`/`.map`”.
2
16
 
3
- These decorators extend the base Prefect decorators with automatic tracing capabilities.
17
+ This keeps Pyright happy without altering runtime behavior and avoids
18
+ leaking advanced typing constructs (like ``ParamSpec``) that confuse tools
19
+ that introspect callables (e.g., Pydantic).
20
+
21
+ Quick start
22
+ -----------
23
+ from ai_pipeline_core.pipeline import pipeline_task, pipeline_flow
24
+ from ai_pipeline_core.documents import DocumentList
25
+ from ai_pipeline_core.flow.options import FlowOptions
26
+
27
+ @pipeline_task
28
+ async def add(x: int, y: int) -> int:
29
+ return x + y
30
+
31
+ @pipeline_flow
32
+ async def my_flow(project_name: str, docs: DocumentList, opts: FlowOptions) -> DocumentList:
33
+ await add(1, 2) # awaitable and typed
34
+ return docs
35
+
36
+ Rules
37
+ -----
38
+ • Your decorated function **must** be ``async def``.
39
+ • ``@pipeline_flow`` functions must accept at least:
40
+ (project_name: str, documents: DocumentList, flow_options: FlowOptions | subclass).
41
+ • Both wrappers return the same Prefect objects you’d get from Prefect directly.
4
42
  """
5
43
 
44
+ from __future__ import annotations
45
+
6
46
  import datetime
7
- import functools
8
47
  import inspect
9
- from typing import (
10
- TYPE_CHECKING,
11
- Any,
12
- Callable,
13
- Coroutine,
14
- Dict,
15
- Iterable,
16
- Optional,
17
- TypeVar,
18
- Union,
19
- cast,
20
- overload,
21
- )
48
+ from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeVar, Union, cast, overload
22
49
 
23
50
  from prefect.assets import Asset
24
51
  from prefect.cache_policies import CachePolicy
25
52
  from prefect.context import TaskRunContext
26
- from prefect.flows import Flow, FlowStateHook
53
+ from prefect.flows import FlowStateHook
54
+ from prefect.flows import flow as _prefect_flow # public import
27
55
  from prefect.futures import PrefectFuture
28
56
  from prefect.results import ResultSerializer, ResultStorage
29
57
  from prefect.task_runners import TaskRunner
30
- from prefect.tasks import (
31
- RetryConditionCallable,
32
- StateHookCallable,
33
- Task,
34
- TaskRunNameValueOrCallable,
35
- )
58
+ from prefect.tasks import task as _prefect_task # public import
36
59
  from prefect.utilities.annotations import NotSet
37
- from typing_extensions import Concatenate, ParamSpec
60
+ from typing_extensions import TypeAlias
38
61
 
39
62
  from ai_pipeline_core.documents import DocumentList
40
63
  from ai_pipeline_core.flow.options import FlowOptions
41
- from ai_pipeline_core.prefect import flow, task
42
64
  from ai_pipeline_core.tracing import TraceLevel, trace
43
65
 
44
- if TYPE_CHECKING:
45
- pass
66
+ # --------------------------------------------------------------------------- #
67
+ # Public callback aliases (Prefect stubs omit these exact types)
68
+ # --------------------------------------------------------------------------- #
69
+ RetryConditionCallable: TypeAlias = Callable[[Any, Any, Any], bool]
70
+ StateHookCallable: TypeAlias = Callable[[Any, Any, Any], None]
71
+ TaskRunNameValueOrCallable: TypeAlias = Union[str, Callable[[], str]]
46
72
 
47
- P = ParamSpec("P")
48
- R = TypeVar("R")
73
+ # --------------------------------------------------------------------------- #
74
+ # Typing helpers
75
+ # --------------------------------------------------------------------------- #
76
+ R_co = TypeVar("R_co", covariant=True)
77
+ FO_contra = TypeVar("FO_contra", bound=FlowOptions, contravariant=True)
78
+ """Flow options are an *input* type, so contravariant fits the callable model."""
49
79
 
50
- # ============================================================================
51
- # PIPELINE TASK DECORATOR
52
- # ============================================================================
53
80
 
81
+ class _TaskLike(Protocol[R_co]):
82
+ """Minimal 'task-like' view: awaitable call + common helpers."""
83
+
84
+ def __call__(self, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, R_co]: ...
85
+
86
+ submit: Callable[..., Any]
87
+ map: Callable[..., Any]
88
+ name: str | None
89
+
90
+ def __getattr__(self, name: str) -> Any: ... # allow unknown helpers without type errors
54
91
 
55
- @overload
56
- def pipeline_task(__fn: Callable[P, R], /) -> Task[P, R]: ...
57
92
 
93
+ class _DocumentsFlowCallable(Protocol[FO_contra]):
94
+ """User async flow signature (first three params fixed)."""
58
95
 
96
+ def __call__(
97
+ self,
98
+ project_name: str,
99
+ documents: DocumentList,
100
+ flow_options: FO_contra,
101
+ *args: Any,
102
+ **kwargs: Any,
103
+ ) -> Coroutine[Any, Any, DocumentList]: ...
104
+
105
+
106
+ class _FlowLike(Protocol[FO_contra]):
107
+ """Callable returned by Prefect ``@flow`` wrapper that we expose to users."""
108
+
109
+ def __call__(
110
+ self,
111
+ project_name: str,
112
+ documents: DocumentList,
113
+ flow_options: FO_contra,
114
+ *args: Any,
115
+ **kwargs: Any,
116
+ ) -> Coroutine[Any, Any, DocumentList]: ...
117
+
118
+ name: str | None
119
+
120
+ def __getattr__(self, name: str) -> Any: ... # allow unknown helpers without type errors
121
+
122
+
123
+ # --------------------------------------------------------------------------- #
124
+ # Small helper: safely get a callable's name without upsetting the type checker
125
+ # --------------------------------------------------------------------------- #
126
+ def _callable_name(obj: Any, fallback: str) -> str:
127
+ try:
128
+ n = getattr(obj, "__name__", None)
129
+ return n if isinstance(n, str) else fallback
130
+ except Exception:
131
+ return fallback
132
+
133
+
134
+ # --------------------------------------------------------------------------- #
135
+ # @pipeline_task — async-only, traced, returns Prefect's Task object
136
+ # --------------------------------------------------------------------------- #
137
+ @overload
138
+ def pipeline_task(__fn: Callable[..., Coroutine[Any, Any, R_co]], /) -> _TaskLike[R_co]: ...
59
139
  @overload
60
140
  def pipeline_task(
61
141
  *,
62
- # Tracing parameters
142
+ # tracing
63
143
  trace_level: TraceLevel = "always",
64
144
  trace_ignore_input: bool = False,
65
145
  trace_ignore_output: bool = False,
66
146
  trace_ignore_inputs: list[str] | None = None,
67
- trace_input_formatter: Optional[Callable[..., str]] = None,
68
- trace_output_formatter: Optional[Callable[..., str]] = None,
69
- # Prefect parameters
70
- name: Optional[str] = None,
71
- description: Optional[str] = None,
72
- tags: Optional[Iterable[str]] = None,
73
- version: Optional[str] = None,
74
- cache_policy: Union[CachePolicy, type[NotSet]] = NotSet,
75
- cache_key_fn: Optional[Callable[[TaskRunContext, Dict[str, Any]], Optional[str]]] = None,
76
- cache_expiration: Optional[datetime.timedelta] = None,
77
- task_run_name: Optional[TaskRunNameValueOrCallable] = None,
78
- retries: Optional[int] = None,
79
- retry_delay_seconds: Optional[
80
- Union[float, int, list[float], Callable[[int], list[float]]]
81
- ] = None,
82
- retry_jitter_factor: Optional[float] = None,
83
- persist_result: Optional[bool] = None,
84
- result_storage: Optional[Union[ResultStorage, str]] = None,
85
- result_serializer: Optional[Union[ResultSerializer, str]] = None,
86
- result_storage_key: Optional[str] = None,
147
+ trace_input_formatter: Callable[..., str] | None = None,
148
+ trace_output_formatter: Callable[..., str] | None = None,
149
+ # prefect passthrough
150
+ name: str | None = None,
151
+ description: str | None = None,
152
+ tags: Iterable[str] | None = None,
153
+ version: str | None = None,
154
+ cache_policy: CachePolicy | type[NotSet] = NotSet,
155
+ cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
156
+ cache_expiration: datetime.timedelta | None = None,
157
+ task_run_name: TaskRunNameValueOrCallable | None = None,
158
+ retries: int | None = None,
159
+ retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
160
+ retry_jitter_factor: float | None = None,
161
+ persist_result: bool | None = None,
162
+ result_storage: ResultStorage | str | None = None,
163
+ result_serializer: ResultSerializer | str | None = None,
164
+ result_storage_key: str | None = None,
87
165
  cache_result_in_memory: bool = True,
88
- timeout_seconds: Union[int, float, None] = None,
89
- log_prints: Optional[bool] = False,
90
- refresh_cache: Optional[bool] = None,
91
- on_completion: Optional[list[StateHookCallable]] = None,
92
- on_failure: Optional[list[StateHookCallable]] = None,
93
- retry_condition_fn: Optional[RetryConditionCallable] = None,
94
- viz_return_value: Optional[bool] = None,
95
- asset_deps: Optional[list[Union[str, Asset]]] = None,
96
- ) -> Callable[[Callable[P, R]], Task[P, R]]: ...
166
+ timeout_seconds: int | float | None = None,
167
+ log_prints: bool | None = False,
168
+ refresh_cache: bool | None = None,
169
+ on_completion: list[StateHookCallable] | None = None,
170
+ on_failure: list[StateHookCallable] | None = None,
171
+ retry_condition_fn: RetryConditionCallable | None = None,
172
+ viz_return_value: bool | None = None,
173
+ asset_deps: list[str | Asset] | None = None,
174
+ ) -> Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]: ...
97
175
 
98
176
 
99
177
  def pipeline_task(
100
- __fn: Optional[Callable[P, R]] = None,
178
+ __fn: Callable[..., Coroutine[Any, Any, R_co]] | None = None,
101
179
  /,
102
180
  *,
103
- # Tracing parameters
181
+ # tracing
104
182
  trace_level: TraceLevel = "always",
105
183
  trace_ignore_input: bool = False,
106
184
  trace_ignore_output: bool = False,
107
185
  trace_ignore_inputs: list[str] | None = None,
108
- trace_input_formatter: Optional[Callable[..., str]] = None,
109
- trace_output_formatter: Optional[Callable[..., str]] = None,
110
- # Prefect parameters
111
- name: Optional[str] = None,
112
- description: Optional[str] = None,
113
- tags: Optional[Iterable[str]] = None,
114
- version: Optional[str] = None,
115
- cache_policy: Union[CachePolicy, type[NotSet]] = NotSet,
116
- cache_key_fn: Optional[Callable[[TaskRunContext, Dict[str, Any]], Optional[str]]] = None,
117
- cache_expiration: Optional[datetime.timedelta] = None,
118
- task_run_name: Optional[TaskRunNameValueOrCallable] = None,
119
- retries: Optional[int] = None,
120
- retry_delay_seconds: Optional[
121
- Union[float, int, list[float], Callable[[int], list[float]]]
122
- ] = None,
123
- retry_jitter_factor: Optional[float] = None,
124
- persist_result: Optional[bool] = None,
125
- result_storage: Optional[Union[ResultStorage, str]] = None,
126
- result_serializer: Optional[Union[ResultSerializer, str]] = None,
127
- result_storage_key: Optional[str] = None,
186
+ trace_input_formatter: Callable[..., str] | None = None,
187
+ trace_output_formatter: Callable[..., str] | None = None,
188
+ # prefect passthrough
189
+ name: str | None = None,
190
+ description: str | None = None,
191
+ tags: Iterable[str] | None = None,
192
+ version: str | None = None,
193
+ cache_policy: CachePolicy | type[NotSet] = NotSet,
194
+ cache_key_fn: Callable[[TaskRunContext, dict[str, Any]], str | None] | None = None,
195
+ cache_expiration: datetime.timedelta | None = None,
196
+ task_run_name: TaskRunNameValueOrCallable | None = None,
197
+ retries: int | None = None,
198
+ retry_delay_seconds: int | float | list[float] | Callable[[int], list[float]] | None = None,
199
+ retry_jitter_factor: float | None = None,
200
+ persist_result: bool | None = None,
201
+ result_storage: ResultStorage | str | None = None,
202
+ result_serializer: ResultSerializer | str | None = None,
203
+ result_storage_key: str | None = None,
128
204
  cache_result_in_memory: bool = True,
129
- timeout_seconds: Union[int, float, None] = None,
130
- log_prints: Optional[bool] = False,
131
- refresh_cache: Optional[bool] = None,
132
- on_completion: Optional[list[StateHookCallable]] = None,
133
- on_failure: Optional[list[StateHookCallable]] = None,
134
- retry_condition_fn: Optional[RetryConditionCallable] = None,
135
- viz_return_value: Optional[bool] = None,
136
- asset_deps: Optional[list[Union[str, Asset]]] = None,
137
- ) -> Union[Task[P, R], Callable[[Callable[P, R]], Task[P, R]]]:
138
- """
139
- Pipeline task decorator that combines Prefect task functionality with automatic tracing.
140
-
141
- This decorator applies tracing before the Prefect task decorator, allowing you to
142
- monitor task execution with LMNR while maintaining all Prefect functionality.
143
-
144
- Args:
145
- trace_level: Control tracing ("always", "debug", "off")
146
- trace_ignore_input: Whether to ignore input in traces
147
- trace_ignore_output: Whether to ignore output in traces
148
- trace_ignore_inputs: List of input parameter names to ignore
149
- trace_input_formatter: Custom formatter for inputs
150
- trace_output_formatter: Custom formatter for outputs
205
+ timeout_seconds: int | float | None = None,
206
+ log_prints: bool | None = False,
207
+ refresh_cache: bool | None = None,
208
+ on_completion: list[StateHookCallable] | None = None,
209
+ on_failure: list[StateHookCallable] | None = None,
210
+ retry_condition_fn: RetryConditionCallable | None = None,
211
+ viz_return_value: bool | None = None,
212
+ asset_deps: list[str | Asset] | None = None,
213
+ ) -> _TaskLike[R_co] | Callable[[Callable[..., Coroutine[Any, Any, R_co]]], _TaskLike[R_co]]:
214
+ """Decorate an **async** function as a traced Prefect task."""
215
+ task_decorator: Callable[..., Any] = _prefect_task # helps the type checker
216
+
217
+ def _apply(fn: Callable[..., Coroutine[Any, Any, R_co]]) -> _TaskLike[R_co]:
218
+ if not inspect.iscoroutinefunction(fn):
219
+ raise TypeError(
220
+ f"@pipeline_task target '{_callable_name(fn, 'task')}' must be 'async def'"
221
+ )
151
222
 
152
- Plus all standard Prefect task parameters...
153
- """
223
+ traced_fn = trace(
224
+ level=trace_level,
225
+ name=name or _callable_name(fn, "task"),
226
+ ignore_input=trace_ignore_input,
227
+ ignore_output=trace_ignore_output,
228
+ ignore_inputs=trace_ignore_inputs,
229
+ input_formatter=trace_input_formatter,
230
+ output_formatter=trace_output_formatter,
231
+ )(fn)
154
232
 
155
- def decorator(fn: Callable[P, R]) -> Task[P, R]:
156
- # Apply tracing first if enabled
157
- if trace_level != "off":
158
- traced_fn = trace(
159
- level=trace_level,
160
- name=name or fn.__name__,
161
- ignore_input=trace_ignore_input,
162
- ignore_output=trace_ignore_output,
163
- ignore_inputs=trace_ignore_inputs,
164
- input_formatter=trace_input_formatter,
165
- output_formatter=trace_output_formatter,
166
- )(fn)
167
- else:
168
- traced_fn = fn
169
-
170
- # Then apply Prefect task decorator
171
- return task( # pyright: ignore[reportCallIssue,reportUnknownVariableType]
172
- traced_fn, # pyright: ignore[reportArgumentType]
173
- name=name,
174
- description=description,
175
- tags=tags,
176
- version=version,
177
- cache_policy=cache_policy,
178
- cache_key_fn=cache_key_fn,
179
- cache_expiration=cache_expiration,
180
- task_run_name=task_run_name,
181
- retries=retries or 0,
182
- retry_delay_seconds=retry_delay_seconds,
183
- retry_jitter_factor=retry_jitter_factor,
184
- persist_result=persist_result,
185
- result_storage=result_storage,
186
- result_serializer=result_serializer,
187
- result_storage_key=result_storage_key,
188
- cache_result_in_memory=cache_result_in_memory,
189
- timeout_seconds=timeout_seconds,
190
- log_prints=log_prints,
191
- refresh_cache=refresh_cache,
192
- on_completion=on_completion,
193
- on_failure=on_failure,
194
- retry_condition_fn=retry_condition_fn,
195
- viz_return_value=viz_return_value,
196
- asset_deps=asset_deps,
233
+ return cast(
234
+ _TaskLike[R_co],
235
+ task_decorator(
236
+ name=name,
237
+ description=description,
238
+ tags=tags,
239
+ version=version,
240
+ cache_policy=cache_policy,
241
+ cache_key_fn=cache_key_fn,
242
+ cache_expiration=cache_expiration,
243
+ task_run_name=task_run_name,
244
+ retries=0 if retries is None else retries,
245
+ retry_delay_seconds=retry_delay_seconds,
246
+ retry_jitter_factor=retry_jitter_factor,
247
+ persist_result=persist_result,
248
+ result_storage=result_storage,
249
+ result_serializer=result_serializer,
250
+ result_storage_key=result_storage_key,
251
+ cache_result_in_memory=cache_result_in_memory,
252
+ timeout_seconds=timeout_seconds,
253
+ log_prints=log_prints,
254
+ refresh_cache=refresh_cache,
255
+ on_completion=on_completion,
256
+ on_failure=on_failure,
257
+ retry_condition_fn=retry_condition_fn,
258
+ viz_return_value=viz_return_value,
259
+ asset_deps=asset_deps,
260
+ )(traced_fn),
197
261
  )
198
262
 
199
- if __fn:
200
- return decorator(__fn)
201
- return decorator
202
-
203
-
204
- # ============================================================================
205
- # PIPELINE FLOW DECORATOR WITH DOCUMENT PROCESSING
206
- # ============================================================================
207
-
208
- # Type aliases for document flow signatures
209
- DocumentsFlowSig = Callable[
210
- Concatenate[str, DocumentList, FlowOptions, P],
211
- Union[DocumentList, Coroutine[Any, Any, DocumentList]],
212
- ]
213
-
214
- DocumentsFlowResult = Flow[Concatenate[str, DocumentList, FlowOptions, P], DocumentList]
263
+ return _apply(__fn) if __fn else _apply
215
264
 
216
265
 
266
+ # --------------------------------------------------------------------------- #
267
+ # @pipeline_flow — async-only, traced, returns Prefect’s flow wrapper
268
+ # --------------------------------------------------------------------------- #
217
269
  @overload
218
- def pipeline_flow(
219
- __fn: DocumentsFlowSig[P],
220
- /,
221
- ) -> DocumentsFlowResult[P]: ...
222
-
223
-
270
+ def pipeline_flow(__fn: _DocumentsFlowCallable[FO_contra], /) -> _FlowLike[FO_contra]: ...
224
271
  @overload
225
272
  def pipeline_flow(
226
273
  *,
227
- # Tracing parameters
274
+ # tracing
228
275
  trace_level: TraceLevel = "always",
229
276
  trace_ignore_input: bool = False,
230
277
  trace_ignore_output: bool = False,
231
278
  trace_ignore_inputs: list[str] | None = None,
232
- trace_input_formatter: Optional[Callable[..., str]] = None,
233
- trace_output_formatter: Optional[Callable[..., str]] = None,
234
- # Prefect parameters
235
- name: Optional[str] = None,
236
- version: Optional[str] = None,
237
- flow_run_name: Optional[Union[Callable[[], str], str]] = None,
238
- retries: Optional[int] = None,
239
- retry_delay_seconds: Optional[Union[int, float]] = None,
240
- task_runner: Optional[TaskRunner[PrefectFuture[Any]]] = None,
241
- description: Optional[str] = None,
242
- timeout_seconds: Union[int, float, None] = None,
279
+ trace_input_formatter: Callable[..., str] | None = None,
280
+ trace_output_formatter: Callable[..., str] | None = None,
281
+ # prefect passthrough
282
+ name: str | None = None,
283
+ version: str | None = None,
284
+ flow_run_name: Union[Callable[[], str], str] | None = None,
285
+ retries: int | None = None,
286
+ retry_delay_seconds: int | float | None = None,
287
+ task_runner: TaskRunner[PrefectFuture[Any]] | None = None,
288
+ description: str | None = None,
289
+ timeout_seconds: int | float | None = None,
243
290
  validate_parameters: bool = True,
244
- persist_result: Optional[bool] = None,
245
- result_storage: Optional[Union[ResultStorage, str]] = None,
246
- result_serializer: Optional[Union[ResultSerializer, str]] = None,
291
+ persist_result: bool | None = None,
292
+ result_storage: ResultStorage | str | None = None,
293
+ result_serializer: ResultSerializer | str | None = None,
247
294
  cache_result_in_memory: bool = True,
248
- log_prints: Optional[bool] = None,
249
- on_completion: Optional[list["FlowStateHook[..., Any]"]] = None,
250
- on_failure: Optional[list["FlowStateHook[..., Any]"]] = None,
251
- on_cancellation: Optional[list["FlowStateHook[..., Any]"]] = None,
252
- on_crashed: Optional[list["FlowStateHook[..., Any]"]] = None,
253
- on_running: Optional[list["FlowStateHook[..., Any]"]] = None,
254
- ) -> Callable[[DocumentsFlowSig[P]], DocumentsFlowResult[P]]: ...
295
+ log_prints: bool | None = None,
296
+ on_completion: list[FlowStateHook[Any, Any]] | None = None,
297
+ on_failure: list[FlowStateHook[Any, Any]] | None = None,
298
+ on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
299
+ on_crashed: list[FlowStateHook[Any, Any]] | None = None,
300
+ on_running: list[FlowStateHook[Any, Any]] | None = None,
301
+ ) -> Callable[[_DocumentsFlowCallable[FO_contra]], _FlowLike[FO_contra]]: ...
255
302
 
256
303
 
257
304
  def pipeline_flow(
258
- __fn: Optional[DocumentsFlowSig[P]] = None,
305
+ __fn: _DocumentsFlowCallable[FO_contra] | None = None,
259
306
  /,
260
307
  *,
261
- # Tracing parameters
308
+ # tracing
262
309
  trace_level: TraceLevel = "always",
263
310
  trace_ignore_input: bool = False,
264
311
  trace_ignore_output: bool = False,
265
312
  trace_ignore_inputs: list[str] | None = None,
266
- trace_input_formatter: Optional[Callable[..., str]] = None,
267
- trace_output_formatter: Optional[Callable[..., str]] = None,
268
- # Prefect parameters
269
- name: Optional[str] = None,
270
- version: Optional[str] = None,
271
- flow_run_name: Optional[Union[Callable[[], str], str]] = None,
272
- retries: Optional[int] = None,
273
- retry_delay_seconds: Optional[Union[int, float]] = None,
274
- task_runner: Optional[TaskRunner[PrefectFuture[Any]]] = None,
275
- description: Optional[str] = None,
276
- timeout_seconds: Union[int, float, None] = None,
313
+ trace_input_formatter: Callable[..., str] | None = None,
314
+ trace_output_formatter: Callable[..., str] | None = None,
315
+ # prefect passthrough
316
+ name: str | None = None,
317
+ version: str | None = None,
318
+ flow_run_name: Union[Callable[[], str], str] | None = None,
319
+ retries: int | None = None,
320
+ retry_delay_seconds: int | float | None = None,
321
+ task_runner: TaskRunner[PrefectFuture[Any]] | None = None,
322
+ description: str | None = None,
323
+ timeout_seconds: int | float | None = None,
277
324
  validate_parameters: bool = True,
278
- persist_result: Optional[bool] = None,
279
- result_storage: Optional[Union[ResultStorage, str]] = None,
280
- result_serializer: Optional[Union[ResultSerializer, str]] = None,
325
+ persist_result: bool | None = None,
326
+ result_storage: ResultStorage | str | None = None,
327
+ result_serializer: ResultSerializer | str | None = None,
281
328
  cache_result_in_memory: bool = True,
282
- log_prints: Optional[bool] = None,
283
- on_completion: Optional[list["FlowStateHook[..., Any]"]] = None,
284
- on_failure: Optional[list["FlowStateHook[..., Any]"]] = None,
285
- on_cancellation: Optional[list["FlowStateHook[..., Any]"]] = None,
286
- on_crashed: Optional[list["FlowStateHook[..., Any]"]] = None,
287
- on_running: Optional[list["FlowStateHook[..., Any]"]] = None,
288
- ) -> Union[DocumentsFlowResult[P], Callable[[DocumentsFlowSig[P]], DocumentsFlowResult[P]]]:
289
- """
290
- Pipeline flow for document processing with standardized signature.
291
-
292
- This decorator enforces a specific signature for document processing flows:
293
- - First parameter: project_name (str)
294
- - Second parameter: documents (DocumentList)
295
- - Third parameter: flow_options (FlowOptions or subclass)
296
- - Additional parameters allowed
297
- - Must return DocumentList
298
-
299
- It includes automatic tracing and all Prefect flow functionality.
300
-
301
- Args:
302
- trace_level: Control tracing ("always", "debug", "off")
303
- trace_ignore_input: Whether to ignore input in traces
304
- trace_ignore_output: Whether to ignore output in traces
305
- trace_ignore_inputs: List of input parameter names to ignore
306
- trace_input_formatter: Custom formatter for inputs
307
- trace_output_formatter: Custom formatter for outputs
308
-
309
- Plus all standard Prefect flow parameters...
329
+ log_prints: bool | None = None,
330
+ on_completion: list[FlowStateHook[Any, Any]] | None = None,
331
+ on_failure: list[FlowStateHook[Any, Any]] | None = None,
332
+ on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
333
+ on_crashed: list[FlowStateHook[Any, Any]] | None = None,
334
+ on_running: list[FlowStateHook[Any, Any]] | None = None,
335
+ ) -> _FlowLike[FO_contra] | Callable[[_DocumentsFlowCallable[FO_contra]], _FlowLike[FO_contra]]:
336
+ """Decorate an **async** flow.
337
+
338
+ Required signature:
339
+ async def flow_fn(
340
+ project_name: str,
341
+ documents: DocumentList,
342
+ flow_options: FlowOptions, # or any subclass
343
+ *args,
344
+ **kwargs
345
+ ) -> DocumentList
346
+
347
+ Returns the same callable object Prefect’s ``@flow`` would return.
310
348
  """
349
+ flow_decorator: Callable[..., Any] = _prefect_flow
311
350
 
312
- def decorator(func: DocumentsFlowSig[P]) -> DocumentsFlowResult[P]:
313
- sig = inspect.signature(func)
314
- params = list(sig.parameters.values())
351
+ def _apply(fn: _DocumentsFlowCallable[FO_contra]) -> _FlowLike[FO_contra]:
352
+ fname = _callable_name(fn, "flow")
315
353
 
316
- if len(params) < 3:
354
+ if not inspect.iscoroutinefunction(fn):
355
+ raise TypeError(f"@pipeline_flow '{fname}' must be declared with 'async def'")
356
+ if len(inspect.signature(fn).parameters) < 3:
317
357
  raise TypeError(
318
- f"@pipeline_flow '{func.__name__}' must accept at least 3 arguments: "
319
- "(project_name, documents, flow_options)"
358
+ f"@pipeline_flow '{fname}' must accept "
359
+ "'project_name, documents, flow_options' as its first three parameters"
320
360
  )
321
361
 
322
- # Validate parameter types (optional but recommended)
323
- # We check names as a convention, not strict type checking at decoration time
324
- expected_names = ["project_name", "documents", "flow_options"]
325
- for i, expected in enumerate(expected_names):
326
- if i < len(params) and params[i].name != expected:
327
- print(
328
- f"Warning: Parameter {i + 1} of '{func.__name__}' is named '{params[i].name}' "
329
- f"but convention suggests '{expected}'"
362
+ async def _wrapper(
363
+ project_name: str,
364
+ documents: DocumentList,
365
+ flow_options: FO_contra,
366
+ *args: Any,
367
+ **kwargs: Any,
368
+ ) -> DocumentList:
369
+ result = await fn(project_name, documents, flow_options, *args, **kwargs)
370
+ if not isinstance(result, DocumentList): # pyright: ignore[reportUnnecessaryIsInstance]
371
+ raise TypeError(
372
+ f"Flow '{fname}' must return DocumentList, got {type(result).__name__}"
330
373
  )
374
+ return result
375
+
376
+ traced = trace(
377
+ level=trace_level,
378
+ name=name or fname,
379
+ ignore_input=trace_ignore_input,
380
+ ignore_output=trace_ignore_output,
381
+ ignore_inputs=trace_ignore_inputs,
382
+ input_formatter=trace_input_formatter,
383
+ output_formatter=trace_output_formatter,
384
+ )(_wrapper)
331
385
 
332
- # Create wrapper that ensures return type
333
- if inspect.iscoroutinefunction(func):
334
-
335
- @functools.wraps(func)
336
- async def wrapper( # pyright: ignore[reportRedeclaration]
337
- project_name: str,
338
- documents: DocumentList,
339
- flow_options: FlowOptions,
340
- *args, # pyright: ignore[reportMissingParameterType]
341
- **kwargs, # pyright: ignore[reportMissingParameterType]
342
- ) -> DocumentList:
343
- result = await func(project_name, documents, flow_options, *args, **kwargs)
344
- # Runtime type checking
345
- DL = DocumentList # Avoid recomputation
346
- if not isinstance(result, DL):
347
- raise TypeError(
348
- f"Flow '{func.__name__}' must return a DocumentList, "
349
- f"but returned {type(result).__name__}"
350
- )
351
- return result
352
- else:
353
-
354
- @functools.wraps(func)
355
- def wrapper( # pyright: ignore[reportRedeclaration]
356
- project_name: str,
357
- documents: DocumentList,
358
- flow_options: FlowOptions,
359
- *args, # pyright: ignore[reportMissingParameterType]
360
- **kwargs, # pyright: ignore[reportMissingParameterType]
361
- ) -> DocumentList:
362
- result = func(project_name, documents, flow_options, *args, **kwargs)
363
- # Runtime type checking
364
- DL = DocumentList # Avoid recomputation
365
- if not isinstance(result, DL):
366
- raise TypeError(
367
- f"Flow '{func.__name__}' must return a DocumentList, "
368
- f"but returned {type(result).__name__}"
369
- )
370
- return result
371
-
372
- # Apply tracing first if enabled
373
- if trace_level != "off":
374
- traced_wrapper = trace(
375
- level=trace_level,
376
- name=name or func.__name__,
377
- ignore_input=trace_ignore_input,
378
- ignore_output=trace_ignore_output,
379
- ignore_inputs=trace_ignore_inputs,
380
- input_formatter=trace_input_formatter,
381
- output_formatter=trace_output_formatter,
382
- )(wrapper)
383
- else:
384
- traced_wrapper = wrapper
385
-
386
- # Then apply Prefect flow decorator
387
386
  return cast(
388
- DocumentsFlowResult[P],
389
- flow( # pyright: ignore[reportCallIssue,reportUnknownVariableType]
390
- traced_wrapper, # pyright: ignore[reportArgumentType]
387
+ _FlowLike[FO_contra],
388
+ flow_decorator(
391
389
  name=name,
392
390
  version=version,
393
391
  flow_run_name=flow_run_name,
394
- retries=retries,
392
+ retries=0 if retries is None else retries,
395
393
  retry_delay_seconds=retry_delay_seconds,
396
394
  task_runner=task_runner,
397
395
  description=description,
@@ -407,12 +405,10 @@ def pipeline_flow(
407
405
  on_cancellation=on_cancellation,
408
406
  on_crashed=on_crashed,
409
407
  on_running=on_running,
410
- ),
408
+ )(traced),
411
409
  )
412
410
 
413
- if __fn:
414
- return decorator(__fn)
415
- return decorator
411
+ return _apply(__fn) if __fn else _apply
416
412
 
417
413
 
418
414
  __all__ = ["pipeline_task", "pipeline_flow"]