stepcraft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stepcraft/__init__.py +58 -0
- stepcraft/async_runtime.py +50 -0
- stepcraft/branching.py +61 -0
- stepcraft/config.py +26 -0
- stepcraft/constants.py +16 -0
- stepcraft/context.py +66 -0
- stepcraft/decorators.py +160 -0
- stepcraft/exceptions.py +27 -0
- stepcraft/fan.py +153 -0
- stepcraft/graph.py +247 -0
- stepcraft/hooks.py +28 -0
- stepcraft/node.py +57 -0
- stepcraft/pipeline.py +217 -0
- stepcraft/pools.py +84 -0
- stepcraft/py.typed +0 -0
- stepcraft/result.py +24 -0
- stepcraft/runtime.py +52 -0
- stepcraft/spec.py +215 -0
- stepcraft/step.py +445 -0
- stepcraft/typecheck.py +47 -0
- stepcraft/utils.py +59 -0
- stepcraft-0.1.0.dist-info/METADATA +475 -0
- stepcraft-0.1.0.dist-info/RECORD +25 -0
- stepcraft-0.1.0.dist-info/WHEEL +4 -0
- stepcraft-0.1.0.dist-info/licenses/LICENSE +21 -0
stepcraft/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Composable function pipeline framework for Python."""
|
|
2
|
+
|
|
3
|
+
from .async_runtime import (
|
|
4
|
+
HAS_RSLOOP,
|
|
5
|
+
install_rsloop,
|
|
6
|
+
rsloop_policy,
|
|
7
|
+
run_async,
|
|
8
|
+
uninstall_rsloop,
|
|
9
|
+
)
|
|
10
|
+
from .branching import ConditionalStep, SwitchStep
|
|
11
|
+
from .constants import HAS_NUMPY, PIPE
|
|
12
|
+
from .context import get_context
|
|
13
|
+
from .decorators import circuit_breaker, node, piped, retry
|
|
14
|
+
from .exceptions import (
|
|
15
|
+
CircuitBreakerError,
|
|
16
|
+
GraphCycleError,
|
|
17
|
+
PipelineError,
|
|
18
|
+
RetryExhaustedError,
|
|
19
|
+
)
|
|
20
|
+
from .fan import FanInStep, FanOutStep, MapReduceStep
|
|
21
|
+
from .graph import Graph
|
|
22
|
+
from .hooks import StepHook
|
|
23
|
+
from .node import Node
|
|
24
|
+
from .pipeline import Pipeline, PipelineBuilder
|
|
25
|
+
from .pools import _POOLS, cleanup_pools, configure_pools
|
|
26
|
+
from .result import ExecutionResult
|
|
27
|
+
from .runtime import (
|
|
28
|
+
HAS_FREE_THREADING,
|
|
29
|
+
is_gil_enabled,
|
|
30
|
+
threads_provide_true_parallelism,
|
|
31
|
+
)
|
|
32
|
+
from .step import PipeStep
|
|
33
|
+
from .typecheck import apply_step_beartype, beartype_enabled
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
# Core
|
|
37
|
+
'PIPE', 'piped', 'retry', 'circuit_breaker',
|
|
38
|
+
'PipeStep', 'Pipeline', 'PipelineBuilder',
|
|
39
|
+
# OOP
|
|
40
|
+
'Node', 'node',
|
|
41
|
+
# Graph / DAG
|
|
42
|
+
'Graph', 'ConditionalStep', 'SwitchStep',
|
|
43
|
+
# Fan-out/in
|
|
44
|
+
'FanOutStep', 'FanInStep', 'MapReduceStep',
|
|
45
|
+
# Results
|
|
46
|
+
'ExecutionResult',
|
|
47
|
+
# Observability
|
|
48
|
+
'StepHook',
|
|
49
|
+
# Shared context
|
|
50
|
+
'get_context',
|
|
51
|
+
# Errors
|
|
52
|
+
'PipelineError', 'RetryExhaustedError', 'CircuitBreakerError', 'GraphCycleError',
|
|
53
|
+
# Utilities
|
|
54
|
+
'cleanup_pools', 'configure_pools', 'HAS_NUMPY', 'HAS_RSLOOP', 'HAS_FREE_THREADING',
|
|
55
|
+
'is_gil_enabled', 'threads_provide_true_parallelism',
|
|
56
|
+
'run_async', 'install_rsloop', 'uninstall_rsloop', 'rsloop_policy',
|
|
57
|
+
'apply_step_beartype', 'beartype_enabled',
|
|
58
|
+
]
|
stepcraft/async_runtime.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Any, Coroutine, Generator, TypeVar
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T")
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import rsloop
|
|
11
|
+
|
|
12
|
+
HAS_RSLOOP = True
|
|
13
|
+
except ImportError:
|
|
14
|
+
rsloop = None # type: ignore[assignment]
|
|
15
|
+
HAS_RSLOOP = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_async(coro: Coroutine[Any, Any, T]) -> T:
    """Drive *coro* to completion and return its result.

    Dispatches to ``rsloop.run`` when the optional ``rsloop`` package was
    importable at module load (``HAS_RSLOOP``), and to ``asyncio.run``
    otherwise.
    """
    runner = rsloop.run if HAS_RSLOOP else asyncio.run
    return runner(coro)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def install_rsloop() -> None:
    """Install rsloop as the default asyncio event loop policy.

    Delegates to ``rsloop.install()``.

    Raises:
        ImportError: if the optional ``rsloop`` package is not available.
    """
    if HAS_RSLOOP:
        rsloop.install()
        return
    raise ImportError(
        "rsloop is not installed. Install with: pip install stepcraft[rsloop]"
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def uninstall_rsloop() -> None:
    """Undo ``install_rsloop`` by delegating to ``rsloop.uninstall()``.

    NOTE(review): the original docstring says this restores the previous
    asyncio event loop policy; that is rsloop's behavior, not verified here.

    Raises:
        ImportError: if the optional ``rsloop`` package is not available.
    """
    if HAS_RSLOOP:
        rsloop.uninstall()
        return
    raise ImportError(
        "rsloop is not installed. Install with: pip install stepcraft[rsloop]"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@contextmanager
def rsloop_policy() -> Generator[None, None, None]:
    """Context manager that keeps rsloop installed for the ``with`` body.

    ``install_rsloop()`` runs on entry (and raises ``ImportError`` before the
    body starts when rsloop is missing); ``uninstall_rsloop()`` always runs on
    exit, including when the body raises.
    """
    install_rsloop()
    try:
        yield None
    finally:
        # Runs on both normal exit and exceptions.
        uninstall_rsloop()
|
stepcraft/branching.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Dict, Generic
|
|
5
|
+
|
|
6
|
+
from .constants import R, T
|
|
7
|
+
from .utils import _async_run_branch_value, _get_func_name, _run_branch_value
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class ConditionalStep(Generic[T, R]):
    """Two-way router: pick a branch from ``condition(value)`` and run it.

    The chosen branch (``if_true`` or ``if_false``) is handed, together with
    the input value, to ``_run_branch_value`` / ``_async_run_branch_value``.
    How a ``None`` branch is treated is decided by those helpers.
    """

    # Predicate evaluated on the incoming value.
    condition: Callable[[T], bool]
    # Branch used when the predicate is truthy.
    if_true: Any
    # Branch used otherwise; defaults to None.
    if_false: Any = None

    def run(self, value: T) -> R:
        """Synchronously evaluate the condition and run the selected branch."""
        chosen = self.if_true if self.condition(value) else self.if_false
        return _run_branch_value(chosen, value)

    async def async_run(self, value: T) -> R:
        """Async counterpart of ``run``; the condition itself is called synchronously."""
        chosen = self.if_true if self.condition(value) else self.if_false
        return await _async_run_branch_value(chosen, value)

    @property
    def _func_name(self) -> str:
        """Display name built from the condition callable's name."""
        condition_name = _get_func_name(self.condition)
        return f"ConditionalStep({condition_name})"

    def __or__(self, other):
        """Support ``step | other`` by wrapping this step in a ``Pipeline``."""
        # Imported lazily inside the method, as in the original.
        from .pipeline import Pipeline

        head = Pipeline([self])
        return head | other
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class SwitchStep(Generic[T, R]):
    """Multi-way router: ``key(value)`` selects one entry of ``branches``.

    When the computed key is absent from ``branches`` the ``default`` branch
    is used instead. The selected branch is executed through
    ``_run_branch_value`` / ``_async_run_branch_value``.
    """

    # Maps the incoming value to a branch key.
    key: Callable[[T], str]
    # Branch table, looked up with dict.get.
    branches: Dict[str, Any]
    # Fallback branch for unknown keys; defaults to None.
    default: Any = None

    def _get_branch(self, value: T) -> Any:
        """Return the branch registered for ``key(value)``, else ``default``."""
        return self.branches.get(self.key(value), self.default)

    def run(self, value: T) -> R:
        """Synchronously run the branch selected for *value*."""
        selected = self._get_branch(value)
        return _run_branch_value(selected, value)

    async def async_run(self, value: T) -> R:
        """Async counterpart of ``run``; the key function is called synchronously."""
        selected = self._get_branch(value)
        return await _async_run_branch_value(selected, value)

    @property
    def _func_name(self) -> str:
        """Display name built from the key callable's name."""
        key_name = _get_func_name(self.key)
        return f"SwitchStep({key_name})"

    def __or__(self, other):
        """Support ``step | other`` by wrapping this step in a ``Pipeline``."""
        # Imported lazily inside the method, as in the original.
        from .pipeline import Pipeline

        head = Pipeline([self])
        return head | other
|
stepcraft/config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
class RetryConfig:
    """Immutable retry settings attached to a step.

    NOTE(review): field semantics below are inferred from the names; the code
    that interprets them is not in this module - confirm against the step
    implementation.
    """

    # Number of attempts (presumably total tries, not extra retries).
    attempts: int = 3
    # Wait between attempts (presumably seconds).
    delay: float = 1.0
    # Multiplier applied to the delay (presumably after each failure).
    backoff: float = 2.0
    # Exception types considered retryable.
    errors: Tuple[type, ...] = (Exception,)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class CircuitBreakerConfig:
    """Immutable circuit-breaker settings attached to a step.

    NOTE(review): field semantics below are inferred from the names; the code
    that interprets them is not in this module - confirm against the step
    implementation.
    """

    # Failure count that trips the breaker (presumably consecutive failures).
    threshold: int = 5
    # How long the breaker stays open (presumably seconds).
    timeout: float = 60.0
    # Calls allowed through while half-open.
    half_open_max_calls: int = 1
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CircuitState(Enum):
    """The three states of a circuit breaker.

    NOTE(review): the per-state meanings in the comments follow the usual
    circuit-breaker convention; the state machine itself is not in this module.
    """

    # Presumably: calls flow normally.
    CLOSED = 0
    # Presumably: calls are rejected.
    OPEN = 1
    # Presumably: a limited number of trial calls are allowed.
    HALF_OPEN = 2
|
stepcraft/constants.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import numpy as np
|
|
8
|
+
HAS_NUMPY = True
|
|
9
|
+
except ImportError:
|
|
10
|
+
np = None # type: ignore[assignment]
|
|
11
|
+
HAS_NUMPY = False
|
|
12
|
+
|
|
13
|
+
PIPE: object = object()
|
|
14
|
+
T = TypeVar("T")
|
|
15
|
+
R = TypeVar("R")
|
|
16
|
+
logger = logging.getLogger(__name__)
|
stepcraft/context.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextvars
|
|
4
|
+
import functools
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import Any, Callable, Dict, Iterator, Optional
|
|
7
|
+
|
|
8
|
+
# Shared, read-only-ish context for steps to access during a pipeline run.
|
|
9
|
+
_PIPELINE_CONTEXT: contextvars.ContextVar[Optional[Dict[str, Any]]] = (
|
|
10
|
+
contextvars.ContextVar("stepcraft_context", default=None)
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_context() -> Dict[str, Any]:
    """Return the pipeline context currently bound to ``_PIPELINE_CONTEXT``.

    When no context is bound (the variable holds ``None``) a new empty dict
    is returned. Otherwise the bound dict itself is returned, not a copy.
    """
    active = _PIPELINE_CONTEXT.get()
    if active is None:
        return {}
    return active
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _set_context(ctx: Optional[Dict[str, Any]]) -> contextvars.Token:
    """Bind *ctx* as the current pipeline context.

    Returns the ``contextvars.Token`` needed by ``_reset_context`` to restore
    the previous value.
    """
    token = _PIPELINE_CONTEXT.set(ctx)
    return token
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _reset_context(token: contextvars.Token) -> None:
    """Restore the pipeline context that was active before *token* was issued."""
    _PIPELINE_CONTEXT.reset(token)
    return None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@contextmanager
def activate_context(ctx: Optional[Dict[str, Any]]) -> Iterator[None]:
    """Expose *ctx* through ``get_context()`` for the duration of the block.

    Passing ``None`` is a no-op: the body runs with whatever context is
    already bound. Otherwise the previous context is restored on exit, even
    if the body raises.
    """
    # No token means nothing was bound, so there is nothing to undo on exit.
    token = None if ctx is None else _set_context(ctx)
    try:
        yield
    finally:
        if token is not None:
            _reset_context(token)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _run_with_captured_context(
    ctx: Optional[Dict[str, Any]],
    fn: Callable[..., Any],
    /,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Call ``fn(*args, **kwargs)`` with *ctx* bound as the pipeline context.

    This is the target that ``wrap_worker`` wraps in a ``functools.partial``.
    *ctx* and *fn* are positional-only, so ``kwargs`` may carry keys named
    ``ctx`` or ``fn``. A ``None`` context means "call *fn* as-is". Otherwise
    the previous context is restored after the call, even if *fn* raises.
    """
    if ctx is not None:
        token = _set_context(ctx)
        try:
            return fn(*args, **kwargs)
        finally:
            _reset_context(token)
    return fn(*args, **kwargs)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def wrap_worker(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Return a callable that runs *fn* under the caller's pipeline context.

    The context is read once, when ``wrap_worker`` is called, and frozen into
    a ``functools.partial`` over ``_run_with_captured_context``.
    """
    captured = _PIPELINE_CONTEXT.get()
    return functools.partial(_run_with_captured_context, captured, fn)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Backward-compatible aliases for thread vs process dispatch sites.
|
|
65
|
+
wrap_thread_worker = wrap_worker
|
|
66
|
+
wrap_process_worker = wrap_worker
|
stepcraft/decorators.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from .config import CircuitBreakerConfig, RetryConfig
|
|
6
|
+
from .constants import HAS_NUMPY, logger, np
|
|
7
|
+
from .node import Node
|
|
8
|
+
from .step import PipeStep
|
|
9
|
+
from .typecheck import apply_step_beartype, resolve_output_schema
|
|
10
|
+
from .utils import _get_func_name
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def piped(
    func: Optional[Callable] = None,
    *,
    batch_size: int = 1,
    parallel: Optional[str] = None,
    auto_map: bool = True,
    map: Optional[bool] = None,
    timeout: Optional[float] = None,
    cancel_on_timeout: bool = False,
    schema: Optional[type] = None,
    jit: bool = False,
    vectorize: bool = False,
) -> Union[PipeStep, Callable[[Callable], PipeStep]]:
    """Create a PipeStep from a function.

    Usable bare (``@piped``) or with options (``@piped(parallel=...)``).

    ``map`` is a README-friendly alias for ``auto_map``; when given it wins.

    Args:
        func: The function to wrap. ``None`` means "called with options", in
            which case a decorator is returned.
        batch_size, parallel, auto_map, timeout, cancel_on_timeout: Forwarded
            unchanged to ``PipeStep``. A warning is logged when ``parallel``
            is set together with ``batch_size > 1``.
        map: Alias for ``auto_map``; overrides it when not ``None``.
        schema: Passed to ``resolve_output_schema`` together with the
            function; the result becomes the step's ``schema``.
        jit: Try to wrap the function with
            ``numba.njit(fastmath=True, cache=True, nogil=True)``; logs a
            warning and continues without JIT if numba is not installed.
        vectorize: Try ``numba.vectorize`` first, then ``numpy.vectorize``;
            logs a warning if neither is available.

    Returns:
        A ``PipeStep`` when *func* is given, otherwise a decorator producing
        one.
    """
    if map is not None:
        auto_map = map

    def decorator(f: Callable) -> PipeStep:
        if parallel and batch_size > 1:
            logger.warning(
                "%s: parallel=%r takes precedence over batch_size=%d; "
                "batching is ignored when both are set",
                _get_func_name(f), parallel, batch_size,
            )

        checked = apply_step_beartype(f)
        optimized = checked
        # The schema is resolved from the undecorated function, not the wrapper.
        output_schema = resolve_output_schema(f, schema)

        if jit:
            try:
                import numba
                optimized = numba.njit(fastmath=True, cache=True, nogil=True)(optimized)
            except ImportError:
                logger.warning(
                    "jit=True on %s but numba is not installed; running without JIT. "
                    "Install with: pip install stepcraft[numba]",
                    _get_func_name(f),
                )

        if vectorize:
            vectorized = False
            try:
                import numba
                optimized = numba.vectorize(
                    ['float64(float64)', 'float32(float32)', 'int64(int64)'],
                    nopython=True, cache=True,
                )(optimized)
                vectorized = True
            except ImportError:
                # numba missing: fall back to numpy when it is available.
                if HAS_NUMPY:
                    optimized = np.vectorize(optimized, cache=True)
                    vectorized = True
            if not vectorized:
                logger.warning(
                    "vectorize=True on %s but neither numba nor numpy is available; "
                    "running without vectorization. "
                    "Install with: pip install stepcraft[numba] or stepcraft[numpy]",
                    _get_func_name(f),
                )

        return PipeStep(
            func=optimized,
            batch_size=batch_size,
            parallel=parallel,
            auto_map=auto_map,
            timeout=timeout,
            cancel_on_timeout=cancel_on_timeout,
            schema=output_schema,
        )

    # Test identity, not truthiness: a callable object that defines
    # __bool__/__len__ can be falsy and must still be wrapped, not mistaken
    # for the "called with options only" form.
    if func is None:
        return decorator
    return decorator(func)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def retry(
    *,
    max_attempts: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    errors: Tuple[type, ...] = (Exception,),
) -> Callable[[PipeStep], PipeStep]:
    """Build a decorator that attaches a ``RetryConfig`` to a step.

    The decorator does not mutate its argument: it returns
    ``step.copy(retry_config=...)``. ``max_attempts`` is stored as the
    config's ``attempts`` field; the other arguments map by name.
    """

    def decorator(step: PipeStep) -> PipeStep:
        # A fresh config is built for every decorated step.
        config = RetryConfig(
            attempts=max_attempts,
            delay=delay,
            backoff=backoff,
            errors=errors,
        )
        return step.copy(retry_config=config)

    return decorator
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def circuit_breaker(
    *,
    threshold: int = 5,
    timeout: float = 60.0,
    failure_threshold: Optional[int] = None,
    recovery_timeout: Optional[float] = None,
) -> Callable[[PipeStep], PipeStep]:
    """Add circuit breaker capability to a pipeline step.

    Args:
        threshold: Stored as ``CircuitBreakerConfig.threshold``.
        timeout: Stored as ``CircuitBreakerConfig.timeout``.
        failure_threshold: Alias for ``threshold``; when not ``None`` it
            overrides ``threshold``.
        recovery_timeout: Alias for ``timeout``; when not ``None`` it
            overrides ``timeout``.

    Returns:
        A decorator that returns ``step.copy(circuit_config=...)`` rather
        than mutating the step it receives.
    """
    # The aliases win over the primary names whenever they are supplied.
    if failure_threshold is not None:
        threshold = failure_threshold
    if recovery_timeout is not None:
        timeout = recovery_timeout

    def decorator(step: PipeStep) -> PipeStep:
        return step.copy(
            circuit_config=CircuitBreakerConfig(threshold=threshold, timeout=timeout),
        )

    return decorator
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def node(
    func: Optional[Callable] = None,
    *,
    setup: Optional[Callable] = None,
    teardown: Optional[Callable] = None,
) -> Union[Node, Callable[[Callable], Node]]:
    """Decorator to create a Node from a function.

    Usable bare (``@node``) or with options (``@node(setup=...)``).

    Args:
        func: The function to wrap. ``None`` means "called with options", in
            which case a decorator is returned.
        setup: Optional callable assigned as the ``setup`` attribute of the
            generated ``Node`` subclass.
        teardown: Optional callable assigned as the ``teardown`` attribute of
            the generated ``Node`` subclass.

    Returns:
        An instance of a one-off ``Node`` subclass whose ``process`` calls
        the (optionally beartype-checked) function, or a decorator producing
        one.
    """

    def decorator(f: Callable) -> Node:
        display_name = _get_func_name(f)
        checked = apply_step_beartype(f)

        class FuncNode(Node):
            @property
            def _func_name(self) -> str:
                return display_name

            def __repr__(self) -> str:
                return f"{display_name}()"

            def process(self, *args, **kwargs):
                return checked(*args, **kwargs)

        # Set on the class, so a plain function becomes a method and is
        # called with the node instance as its first argument.
        # Identity checks keep falsy-but-valid callables from being dropped.
        if setup is not None:
            FuncNode.setup = setup
        if teardown is not None:
            FuncNode.teardown = teardown

        return FuncNode()

    # Same reasoning as above: only None means "no function supplied".
    if func is None:
        return decorator
    return decorator(func)
|
stepcraft/exceptions.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class PipelineError(Exception):
    """Base pipeline exception with context.

    Attributes:
        func_name: Name of the step that failed.
        original_error: The exception raised by that step.
    """

    def __init__(self, func_name: str, original_error: Exception):
        self.func_name = func_name
        self.original_error = original_error
        super().__init__(f"Pipeline step '{func_name}' failed: {original_error}")

    def __reduce__(self):
        """Make the exception survive pickling and copying.

        ``BaseException.__reduce__`` rebuilds an instance from ``self.args``,
        which here holds only the formatted message. ``__init__`` needs two
        arguments, so unpickling would raise ``TypeError`` - for example when
        the error crosses a process boundary. Rebuild from the real
        constructor arguments instead, and carry ``__dict__`` along so extra
        attributes are kept.
        """
        return (
            type(self),
            (self.func_name, self.original_error),
            self.__dict__,
        )
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RetryExhaustedError(PipelineError):
    """Raised when retry attempts are exhausted.

    Its string form is the ``PipelineError`` message prefixed with
    ``RetryExhausted: ``.
    """

    def __str__(self):
        base_message = super().__str__()
        return f"RetryExhausted: {base_message}"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CircuitBreakerError(PipelineError):
    """Raised when circuit breaker is open.

    Adds nothing to ``PipelineError``; it exists so callers can catch this
    case specifically.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GraphCycleError(Exception):
    """Raised when a cycle is detected in the DAG.

    Derives directly from ``Exception`` (not ``PipelineError``) and takes the
    standard exception arguments.
    """
|
stepcraft/fan.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import itertools
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable, Generic, Iterable, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
from .constants import R, T
|
|
10
|
+
from .context import wrap_worker
|
|
11
|
+
from .pools import _get_pool
|
|
12
|
+
from .runtime import resolve_parallel_kind
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _run_branch_on(branch: Any, value: Any) -> Any:
|
|
16
|
+
"""Module-level (picklable) helper for ProcessPoolExecutor fan-out."""
|
|
17
|
+
return branch.run(value)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class FanOutStep(Generic[T, R]):
    """Feed one value to every branch and collect the results as a tuple.

    Results are returned in the same order as ``branches``. With ``parallel``
    set, branches are dispatched to the pool returned by
    ``_get_pool(parallel)``; otherwise they run in the caller.
    """

    # Objects exposing run(value), and optionally async_run(value).
    branches: Tuple[Any, ...]
    # Pool kind, normalised by resolve_parallel_kind in __post_init__.
    parallel: Optional[str] = None

    def __post_init__(self):
        """Replace ``parallel`` with its resolved form when the two differ."""
        resolved = resolve_parallel_kind(self.parallel)
        if resolved != self.parallel:
            object.__setattr__(self, 'parallel', resolved)

    def run(self, value: T) -> Tuple[R, ...]:
        """Run every branch on *value* synchronously."""
        if not self.parallel:
            return tuple(branch.run(value) for branch in self.branches)

        executor = _get_pool(self.parallel)
        # wrap_worker captures the current pipeline context for the workers.
        worker = wrap_worker(_run_branch_on)
        outputs = executor.map(worker, self.branches, itertools.repeat(value))
        return tuple(outputs)

    async def async_run(self, value: T) -> Tuple[R, ...]:
        """Run every branch on *value* concurrently and await all results."""
        if self.parallel:
            executor = _get_pool(self.parallel)
            loop = asyncio.get_running_loop()
            worker = wrap_worker(_run_branch_on)
            futures = [
                loop.run_in_executor(executor, worker, branch, value)
                for branch in self.branches
            ]
            return tuple(await asyncio.gather(*futures))

        pending = []
        for branch in self.branches:
            if hasattr(branch, 'async_run'):
                pending.append(branch.async_run(value))
                continue
            # Sync-only branch: run it in the loop's default executor.
            loop = asyncio.get_running_loop()
            call = wrap_worker(lambda b=branch, v=value: b.run(v))
            pending.append(loop.run_in_executor(None, call))
        return tuple(await asyncio.gather(*pending))

    @property
    def _func_name(self) -> str:
        """Fixed display name for this step type."""
        return "FanOutStep"

    def __or__(self, other):
        """Support ``step | other`` by wrapping this step in a ``Pipeline``."""
        # Imported lazily inside the method, as in the original.
        from .pipeline import Pipeline

        head = Pipeline([self])
        return head | other
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class FanInStep(Generic[T, R]):
    """Collapse a tuple of inputs into one value via ``combiner(*values)``."""

    # Called with the inputs unpacked as positional arguments.
    combiner: Callable[..., R]

    def run(self, values: Tuple[T, ...]) -> R:
        """Return ``combiner(*values)``."""
        combined = self.combiner(*values)
        return combined

    async def async_run(self, values: Tuple[T, ...]) -> R:
        """Like ``run``, but awaits the result when the combiner returns a coroutine."""
        combined = self.combiner(*values)
        return (await combined) if asyncio.iscoroutine(combined) else combined

    @property
    def _func_name(self) -> str:
        """Fixed display name for this step type."""
        return "FanInStep"

    def __or__(self, other):
        """Support ``step | other`` by wrapping this step in a ``Pipeline``."""
        # Imported lazily inside the method, as in the original.
        from .pipeline import Pipeline

        head = Pipeline([self])
        return head | other
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class MapReduceStep(Generic[T, R]):
    """Apply ``mapper`` to every item, then hand all results to ``reducer``.

    Items are consumed in groups of ``batch_size``. In ``async_run`` the
    members of one group are mapped concurrently and the groups run one after
    another. Mapped results keep input order.
    """

    # Per-item transform; may be a coroutine function.
    mapper: Callable[[T], Any]
    # Receives the list of all mapped results.
    reducer: Callable[[Iterable[Any]], R]
    # Items per group; a group is flushed when its length equals this value.
    batch_size: int = 1
    # Set in __post_init__ from inspect.iscoroutinefunction(mapper).
    _mapper_is_async: bool = field(default=False, init=False)

    def __post_init__(self):
        """Record whether ``mapper`` is a coroutine function."""
        is_async = inspect.iscoroutinefunction(self.mapper)
        object.__setattr__(self, '_mapper_is_async', is_async)

    def _chunked(self, items: Iterable[T]):
        """Yield successive lists of up to ``batch_size`` items from *items*."""
        pending: List[T] = []
        for item in items:
            pending.append(item)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # Trailing partial group, if any.
        if pending:
            yield pending

    async def _map_item(self, item: T) -> Any:
        """Map one item, awaiting the result when it is a coroutine."""
        outcome = self.mapper(item)
        if self._mapper_is_async or asyncio.iscoroutine(outcome):
            outcome = await outcome
        return outcome

    async def _map_batch(self, batch: List[T]) -> List[Any]:
        """Map every item of *batch* concurrently, preserving order."""
        mapped = await asyncio.gather(*(self._map_item(x) for x in batch))
        return list(mapped)

    def run(self, items: Iterable[T]) -> R:
        """Synchronously map all items and reduce the collected results."""
        collected: List[Any] = []
        for chunk in self._chunked(items):
            collected.extend(map(self.mapper, chunk))
        return self.reducer(collected)

    async def async_run(self, items: Iterable[T]) -> R:
        """Async variant of ``run``; also awaits a coroutine returned by the reducer."""
        collected: List[Any] = []
        for chunk in self._chunked(items):
            collected.extend(await self._map_batch(chunk))
        reduced = self.reducer(collected)
        if asyncio.iscoroutine(reduced):
            return await reduced
        return reduced

    @property
    def _func_name(self) -> str:
        """Fixed display name for this step type."""
        return "MapReduceStep"

    def __or__(self, other):
        """Support ``step | other`` by wrapping this step in a ``Pipeline``."""
        # Imported lazily inside the method, as in the original.
        from .pipeline import Pipeline

        head = Pipeline([self])
        return head | other
|