prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/blocks/redis.py +168 -0
- prefect/client/orchestration.py +17 -1
- prefect/client/schemas/objects.py +12 -8
- prefect/concurrency/asyncio.py +1 -1
- prefect/concurrency/services.py +1 -1
- prefect/deployments/base.py +7 -1
- prefect/events/schemas/events.py +2 -0
- prefect/flow_engine.py +2 -2
- prefect/flow_runs.py +2 -2
- prefect/flows.py +8 -1
- prefect/futures.py +44 -43
- prefect/input/run_input.py +4 -2
- prefect/records/cache_policies.py +179 -0
- prefect/settings.py +6 -3
- prefect/states.py +6 -4
- prefect/task_engine.py +169 -198
- prefect/task_runners.py +6 -2
- prefect/task_runs.py +203 -0
- prefect/{task_server.py → task_worker.py} +37 -27
- prefect/tasks.py +49 -22
- prefect/transactions.py +6 -2
- prefect/utilities/callables.py +74 -3
- prefect/utilities/importtools.py +5 -5
- prefect/variables.py +15 -10
- prefect/workers/base.py +11 -1
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/METADATA +2 -1
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/RECORD +30 -27
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/top_level.txt +0 -0
prefect/blocks/redis.py
ADDED
@@ -0,0 +1,168 @@
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import AsyncGenerator, Optional, Union
+
+try:
+    import redis.asyncio as redis
+except ImportError:
+    raise ImportError(
+        "`redis-py` must be installed to use the `RedisStorageContainer` block. "
+        "You can install it with `pip install redis>=5.0.1`"
+    )
+
+from pydantic import Field
+from pydantic.types import SecretStr
+from typing_extensions import Self
+
+from prefect.filesystems import WritableFileSystem
+from prefect.utilities.asyncutils import sync_compatible
+
+
+class RedisStorageContainer(WritableFileSystem):
+    """
+    Block used to interact with Redis as a filesystem
+
+    Attributes:
+        host (str): The Redis hostname.
+        port (int): The Redis port.
+        db (int): The Redis DB index.
+        username (str): The Redis username.
+        password (str): The Redis password.
+        connection_string (str): A Redis connection string.
+
+    Example:
+        Create a new block from hostname, username and password:
+        ```python
+        from prefect.blocks.redis import RedisStorageContainer
+
+        block = RedisStorageContainer.from_host(
+            host="myredishost.com", username="redis", password="SuperSecret")
+        block.save("BLOCK_NAME")
+        ```
+
+        Create a new block from a connection string:
+        ```python
+        from prefect.blocks.redis import RedisStorageContainer
+        block = RedisStorageContainer.from_connection_string(
+            "redis://redis:SuperSecret@myredishost.com:6379"
+        )
+        block.save("BLOCK_NAME")
+        ```
+    """
+
+    _logo_url = "https://stprododpcmscdnendpoint.azureedge.net/assets/icons/redis.png"
+
+    host: Optional[str] = Field(default=None, description="Redis hostname")
+    port: int = Field(default=6379, description="Redis port")
+    db: int = Field(default=0, description="Redis DB index")
+    username: Optional[SecretStr] = Field(default=None, description="Redis username")
+    password: Optional[SecretStr] = Field(default=None, description="Redis password")
+    connection_string: Optional[SecretStr] = Field(
+        default=None, description="Redis connection string"
+    )
+
+    def block_initialization(self) -> None:
+        if self.connection_string:
+            return
+        if not self.host:
+            raise ValueError("Initialization error: 'host' is required but missing.")
+        if self.username and not self.password:
+            raise ValueError(
+                "Initialization error: 'username' is provided, but 'password' is missing. Both are required."
+            )
+
+    @sync_compatible
+    async def read_path(self, path: Union[Path, str]):
+        """Read the redis content at `path`
+
+        Args:
+            path: Redis key to read from
+
+        Returns:
+            Contents at key as bytes
+        """
+        async with self._client() as client:
+            return await client.get(str(path))
+
+    @sync_compatible
+    async def write_path(self, path: Union[Path, str], content: bytes):
+        """Write `content` to the redis at `path`
+
+        Args:
+            path: Redis key to write to
+            content: Binary object to write
+        """
+
+        async with self._client() as client:
+            return await client.set(str(path), content)
+
+    @asynccontextmanager
+    async def _client(self) -> AsyncGenerator[redis.Redis, None]:
+        if self.connection_string:
+            client = redis.Redis.from_url(self.connection_string.get_secret_value())
+        else:
+            assert self.host
+            client = redis.Redis(
+                host=self.host,
+                port=self.port,
+                username=self.username.get_secret_value() if self.username else None,
+                password=self.password.get_secret_value() if self.password else None,
+                db=self.db,
+            )
+
+        try:
+            yield client
+        finally:
+            await client.aclose()
+
+    @classmethod
+    def from_host(
+        cls,
+        host: str,
+        port: int = 6379,
+        db: int = 0,
+        username: Union[None, str, SecretStr] = None,
+        password: Union[None, str, SecretStr] = None,
+    ) -> Self:
+        """Create block from hostname, username and password
+
+        Args:
+            host: Redis hostname
+            username: Redis username
+            password: Redis password
+            port: Redis port
+
+        Returns:
+            `RedisStorageContainer` instance
+        """
+
+        username = SecretStr(username) if isinstance(username, str) else username
+        password = SecretStr(password) if isinstance(password, str) else password
+
+        return cls(host=host, port=port, db=db, username=username, password=password)
+
+    @classmethod
+    def from_connection_string(cls, connection_string: Union[str, SecretStr]) -> Self:
+        """Create block from a Redis connection string
+
+        Supports the following URL schemes:
+        - `redis://` creates a TCP socket connection
+        - `rediss://` creates an SSL-wrapped TCP socket connection
+        - `unix://` creates a Unix Domain Socket connection
+
+        See [Redis docs](https://redis.readthedocs.io/en/stable/examples
+        /connection_examples.html#Connecting-to-Redis-instances-by-specifying-a-URL
+        -scheme.) for more info.
+
+        Args:
+            connection_string: Redis connection string
+
+        Returns:
+            `RedisStorageContainer` instance
+        """
+
+        connection_string = (
+            SecretStr(connection_string)
+            if isinstance(connection_string, str)
+            else connection_string
+        )
+
+        return cls(connection_string=connection_string)
prefect/client/orchestration.py
CHANGED
@@ -8,11 +8,13 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 from uuid import UUID, uuid4

@@ -156,9 +158,23 @@ class ServerType(AutoEnum):
     CLOUD = AutoEnum.auto()


+@overload
+def get_client(
+    httpx_settings: Optional[Dict[str, Any]] = None, sync_client: Literal[False] = False
+) -> "PrefectClient":
+    ...
+
+
+@overload
+def get_client(
+    httpx_settings: Optional[Dict[str, Any]] = None, sync_client: Literal[True] = True
+) -> "SyncPrefectClient":
+    ...
+
+
 def get_client(
     httpx_settings: Optional[Dict[str, Any]] = None, sync_client: bool = False
-)
+):
     """
     Retrieve a HTTP client for communicating with the Prefect REST API.

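The new `@overload` declarations let type checkers infer the concrete client type from the `sync_client` flag. A short sketch of both call styles; the `hello()` health-check call is assumed here only as a simple round trip:

```python
from prefect.client.orchestration import get_client

async def check_async() -> None:
    # Default sync_client=False resolves to the async PrefectClient
    async with get_client() as client:
        await client.hello()

def check_sync() -> None:
    # sync_client=True resolves to SyncPrefectClient
    with get_client(sync_client=True) as client:
        client.hello()
```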
prefect/client/schemas/objects.py
CHANGED
@@ -94,6 +94,14 @@ class StateType(AutoEnum):
     CANCELLING = AutoEnum.auto()


+TERMINAL_STATES = {
+    StateType.COMPLETED,
+    StateType.CANCELLED,
+    StateType.FAILED,
+    StateType.CRASHED,
+}
+
+
 class WorkPoolStatus(AutoEnum):
     """Enumeration of work pool statuses."""

@@ -280,7 +288,7 @@ class State(ObjectBaseModel, Generic[R]):
     def default_scheduled_start_time(self) -> Self:
         if self.type == StateType.SCHEDULED:
             if not self.state_details.scheduled_time:
-                self.state_details.scheduled_time =
+                self.state_details.scheduled_time = DateTime.now("utc")
         return self

     def is_scheduled(self) -> bool:

@@ -308,12 +316,7 @@ class State(ObjectBaseModel, Generic[R]):
         return self.type == StateType.CANCELLING

     def is_final(self) -> bool:
-        return self.type in {
-            StateType.CANCELLED,
-            StateType.FAILED,
-            StateType.COMPLETED,
-            StateType.CRASHED,
-        }
+        return self.type in TERMINAL_STATES

     def is_paused(self) -> bool:
         return self.type == StateType.PAUSED

@@ -550,7 +553,8 @@ class FlowRun(ObjectBaseModel):
         examples=["State(type=StateType.COMPLETED)"],
     )
     job_variables: Optional[dict] = Field(
-        default=None,
+        default=None,
+        description="Job variables for the flow run.",
     )

     # These are server-side optimizations and should not be present on client models
prefect/concurrency/asyncio.py
CHANGED
@@ -11,7 +11,7 @@ except ImportError:
     # pendulum < 3
     from pendulum.period import Period as Interval  # type: ignore

-from prefect import get_client
+from prefect.client.orchestration import get_client
 from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
 from prefect.utilities.timeout import timeout_async

prefect/concurrency/services.py
CHANGED
@@ -10,9 +10,9 @@ from typing import (
 import httpx
 from starlette import status

-from prefect import get_client
 from prefect._internal.concurrency import logger
 from prefect._internal.concurrency.services import QueueService
+from prefect.client.orchestration import get_client

 if TYPE_CHECKING:
     from prefect.client.orchestration import PrefectClient
prefect/deployments/base.py
CHANGED
@@ -398,7 +398,13 @@ async def _find_flow_functions_in_file(filename: str) -> List[Dict]:
         return decorated_functions

     for node in ast.walk(tree):
-        if isinstance(
+        if isinstance(
+            node,
+            (
+                ast.FunctionDef,
+                ast.AsyncFunctionDef,
+            ),
+        ):
             for decorator in node.decorator_list:
                 # handles @flow
                 is_name_match = (
prefect/events/schemas/events.py
CHANGED
@@ -83,6 +83,8 @@ class RelatedResource(Resource):
 class Event(PrefectBaseModel):
     """The client-side view of an event that has happened to a Resource"""

+    model_config = ConfigDict(extra="ignore")
+
     occurred: DateTime = Field(
         default_factory=lambda: pendulum.now("UTC"),
         description="When the event happened from the sender's perspective",
prefect/flow_engine.py
CHANGED
@@ -25,9 +25,9 @@ import anyio._backends._asyncio
 from sniffio import AsyncLibraryNotFoundError
 from typing_extensions import ParamSpec

-from prefect import Task
+from prefect import Task
 from prefect._internal.concurrency.api import create_call, from_sync
-from prefect.client.orchestration import SyncPrefectClient
+from prefect.client.orchestration import SyncPrefectClient, get_client
 from prefect.client.schemas import FlowRun, TaskRun
 from prefect.client.schemas.filters import FlowRunFilter
 from prefect.client.schemas.sorting import FlowRunSort
prefect/flow_runs.py
CHANGED
@@ -76,7 +76,7 @@ async def wait_for_flow_run(
        ```python
        import asyncio

-       from prefect import get_client
+       from prefect.client.orchestration import get_client
        from prefect.flow_runs import wait_for_flow_run

        async def main():

@@ -94,7 +94,7 @@ async def wait_for_flow_run(
        ```python
        import asyncio

-       from prefect import get_client
+       from prefect.client.orchestration import get_client
        from prefect.flow_runs import wait_for_flow_run

        async def main(num_runs: int):
prefect/flows.py
CHANGED
@@ -1913,7 +1913,14 @@ def load_flow_argument_from_entrypoint(
         (
             node
             for node in ast.walk(parsed_code)
-            if isinstance(
+            if isinstance(
+                node,
+                (
+                    ast.FunctionDef,
+                    ast.AsyncFunctionDef,
+                ),
+            )
+            and node.name == func_name
         ),
         None,
     )
prefect/futures.py
CHANGED
@@ -1,7 +1,6 @@
 import abc
 import concurrent.futures
 import inspect
-import time
 import uuid
 from functools import partial
 from typing import Any, Generic, Optional, Set, Union, cast

@@ -11,13 +10,17 @@ from typing_extensions import TypeVar
 from prefect.client.orchestration import get_client
 from prefect.client.schemas.objects import TaskRun
 from prefect.exceptions import ObjectNotFound
+from prefect.logging.loggers import get_logger
 from prefect.states import Pending, State
+from prefect.task_runs import TaskRunWaiter
 from prefect.utilities.annotations import quote
 from prefect.utilities.asyncutils import run_coro_as_sync
 from prefect.utilities.collections import StopVisiting, visit_collection

 F = TypeVar("F")

+logger = get_logger(__name__)
+

 class PrefectFuture(abc.ABC):
     """

@@ -146,68 +149,66 @@ class PrefectDistributedFuture(PrefectFuture):
     Represents the result of a computation happening anywhere.

     This class is typically used to interact with the result of a task run
-    scheduled to run in a Prefect task
+    scheduled to run in a Prefect task worker but can be used to interact with
     any task run scheduled in Prefect's API.
     """

-    def
-    self.
-        self._client = None
-        super().__init__(task_run_id=task_run_id)
-
-    @property
-    def client(self):
-        if self._client is None:
-            self._client = get_client(sync_client=True)
-        return self._client
+    def wait(self, timeout: Optional[float] = None) -> None:
+        return run_coro_as_sync(self.wait_async(timeout=timeout))

-
-
-
-
-        return self._task_run
-
-    @task_run.setter
-    def task_run(self, task_run):
-        self._task_run = task_run
-
-    def wait(
-        self, timeout: Optional[float] = None, polling_interval: Optional[float] = 0.2
-    ) -> None:
-        start_time = time.time()
-        # TODO: Websocket implementation?
-        while True:
-            self.task_run = cast(
-                TaskRun, self.client.read_task_run(task_run_id=self.task_run_id)
+    async def wait_async(self, timeout: Optional[float] = None):
+        if self._final_state:
+            logger.debug(
+                "Final state already set for %s. Returning...", self.task_run_id
             )
-
-
-
-
+            return
+
+        # Read task run to see if it is still running
+        async with get_client() as client:
+            task_run = await client.read_task_run(task_run_id=self._task_run_id)
+            if task_run.state.is_final():
+                logger.debug(
+                    "Task run %s already finished. Returning...",
+                    self.task_run_id,
+                )
+                self._final_state = task_run.state
                 return
-
+
+            # If still running, wait for a completed event from the server
+            logger.debug(
+                "Waiting for completed event for task run %s...",
+                self.task_run_id,
+            )
+            await TaskRunWaiter.wait_for_task_run(self._task_run_id, timeout=timeout)
+            task_run = await client.read_task_run(task_run_id=self._task_run_id)
+            if task_run.state.is_final():
+                self._final_state = task_run.state
+            return

     def result(
         self,
         timeout: Optional[float] = None,
         raise_on_failure: bool = True,
-        polling_interval: Optional[float] = 0.2,
     ) -> Any:
+        return run_coro_as_sync(
+            self.result_async(timeout=timeout, raise_on_failure=raise_on_failure)
+        )
+
+    async def result_async(
+        self,
+        timeout: Optional[float] = None,
+        raise_on_failure: bool = True,
+    ):
         if not self._final_state:
-            self.
+            await self.wait_async(timeout=timeout)
         if not self._final_state:
             raise TimeoutError(
                 f"Task run {self.task_run_id} did not complete within {timeout} seconds"
             )

-
+        return await self._final_state.result(
             raise_on_failure=raise_on_failure, fetch=True
         )
-        # state.result is a `sync_compatible` function that may or may not return an awaitable
-        # depending on whether the parent frame is sync or not
-        if inspect.isawaitable(_result):
-            _result = run_coro_as_sync(_result)
-        return _result

     def __eq__(self, other):
         if not isinstance(other, PrefectDistributedFuture):
prefect/input/run_input.py
CHANGED
@@ -18,7 +18,8 @@ Sender flow:
    ```python
    import random
    from uuid import UUID
-   from prefect import flow
+   from prefect import flow
+   from prefect.logging import get_run_logger
    from prefect.input import RunInput

    class NumberData(RunInput):

@@ -43,7 +44,8 @@ Receiver flow:
    ```python
    import random
    from uuid import UUID
-   from prefect import flow
+   from prefect import flow
+   from prefect.logging import get_run_logger
    from prefect.input import RunInput

    class NumberData(RunInput):
prefect/records/cache_policies.py
ADDED
@@ -0,0 +1,179 @@
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional
+
+from prefect.context import TaskRunContext
+from prefect.utilities.hashing import hash_objects
+
+
+@dataclass
+class CachePolicy:
+    @classmethod
+    def from_cache_key_fn(
+        cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
+    ) -> "CacheKeyFnPolicy":
+        """
+        Given a function, generates a key policy.
+        """
+        return CacheKeyFnPolicy(cache_key_fn=cache_key_fn)
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        raise NotImplementedError
+
+    def __sub__(self, other: str) -> "CompoundCachePolicy":
+        if not isinstance(other, str):
+            raise TypeError("Can only subtract strings from key policies.")
+        if isinstance(self, Inputs):
+            exclude = self.exclude or []
+            return Inputs(exclude=exclude + [other])
+        elif isinstance(self, CompoundCachePolicy):
+            new = Inputs(exclude=[other])
+            policies = self.policies or []
+            return CompoundCachePolicy(policies=policies + [new])
+        else:
+            new = Inputs(exclude=[other])
+            return CompoundCachePolicy(policies=[self, new])
+
+    def __add__(self, other: "CachePolicy") -> "CompoundCachePolicy":
+        # adding _None is a no-op
+        if isinstance(other, _None):
+            return self
+        elif isinstance(self, _None):
+            return other
+
+        if isinstance(self, CompoundCachePolicy):
+            policies = self.policies or []
+            return CompoundCachePolicy(policies=policies + [other])
+        elif isinstance(other, CompoundCachePolicy):
+            policies = other.policies or []
+            return CompoundCachePolicy(policies=policies + [self])
+        else:
+            return CompoundCachePolicy(policies=[self, other])
+
+
+@dataclass
+class CacheKeyFnPolicy(CachePolicy):
+    # making it optional for tests
+    cache_key_fn: Optional[
+        Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
+    ] = None
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        if self.cache_key_fn:
+            return self.cache_key_fn(task_ctx, inputs)
+
+
+@dataclass
+class CompoundCachePolicy(CachePolicy):
+    policies: list = None
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        keys = []
+        for policy in self.policies:
+            keys.append(
+                policy.compute_key(
+                    task_ctx=task_ctx,
+                    inputs=inputs,
+                    flow_parameters=flow_parameters,
+                    **kwargs,
+                )
+            )
+        return hash_objects(*keys)
+
+
+@dataclass
+class Default(CachePolicy):
+    "Execution run ID only"
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        return str(task_ctx.task_run.id)
+
+
+@dataclass
+class _None(CachePolicy):
+    "ignore key policies altogether, always run - prevents persistence"
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        return None
+
+
+@dataclass
+class TaskDef(CachePolicy):
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        lines = inspect.getsource(task_ctx.task)
+        return hash_objects(lines)
+
+
+@dataclass
+class FlowParameters(CachePolicy):
+    pass
+
+
+@dataclass
+class Inputs(CachePolicy):
+    """
+    Exposes flag for whether to include flow parameters as well.
+
+    And exclude/include config.
+    """
+
+    exclude: list = None
+
+    def compute_key(
+        self,
+        task_ctx: TaskRunContext,
+        inputs: Dict[str, Any],
+        flow_parameters: Dict[str, Any],
+        **kwargs,
+    ) -> Optional[str]:
+        hashed_inputs = {}
+        inputs = inputs or {}
+        exclude = self.exclude or []
+
+        for key, val in inputs.items():
+            if key not in exclude:
+                hashed_inputs[key] = val
+
+        return hash_objects(hashed_inputs)
+
+
+DEFAULT = Default()
+INPUTS = Inputs()
+NONE = _None()
+TASKDEF = TaskDef()