prefect-client 3.0.0rc13__py3-none-any.whl → 3.0.0rc15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/_internal/compatibility/deprecated.py +0 -53
- prefect/blocks/core.py +132 -4
- prefect/blocks/notifications.py +26 -3
- prefect/client/base.py +30 -24
- prefect/client/orchestration.py +121 -47
- prefect/client/utilities.py +4 -4
- prefect/concurrency/asyncio.py +48 -7
- prefect/concurrency/context.py +24 -0
- prefect/concurrency/services.py +24 -8
- prefect/concurrency/sync.py +30 -3
- prefect/context.py +85 -24
- prefect/events/clients.py +93 -60
- prefect/events/utilities.py +0 -2
- prefect/events/worker.py +9 -2
- prefect/flow_engine.py +6 -3
- prefect/flows.py +176 -12
- prefect/futures.py +84 -7
- prefect/profiles.toml +16 -2
- prefect/runner/runner.py +6 -1
- prefect/runner/storage.py +4 -0
- prefect/settings.py +108 -14
- prefect/task_engine.py +901 -285
- prefect/task_runs.py +24 -1
- prefect/task_worker.py +7 -1
- prefect/tasks.py +9 -5
- prefect/utilities/asyncutils.py +0 -6
- prefect/utilities/callables.py +5 -3
- prefect/utilities/engine.py +3 -0
- prefect/utilities/importtools.py +138 -58
- prefect/utilities/schema_tools/validation.py +30 -0
- prefect/utilities/services.py +32 -0
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/METADATA +39 -39
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/RECORD +36 -35
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/WHEEL +1 -1
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/top_level.txt +0 -0
prefect/events/clients.py
CHANGED
@@ -15,10 +15,10 @@ from typing import (
 )
 from uuid import UUID
 
-import httpx
 import orjson
 import pendulum
 from cachetools import TTLCache
+from prometheus_client import Counter
 from typing_extensions import Self
 from websockets import Subprotocol
 from websockets.client import WebSocketClientProtocol, connect
@@ -28,20 +28,48 @@ from websockets.exceptions import (
     ConnectionClosedOK,
 )
 
-from prefect.client.base import PrefectHttpxAsyncClient
 from prefect.events import Event
 from prefect.logging import get_logger
-from prefect.settings import
+from prefect.settings import (
+    PREFECT_API_KEY,
+    PREFECT_API_URL,
+    PREFECT_CLOUD_API_URL,
+    PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
+)
 
 if TYPE_CHECKING:
     from prefect.events.filters import EventFilter
 
+EVENTS_EMITTED = Counter(
+    "prefect_events_emitted",
+    "The number of events emitted by Prefect event clients",
+    labelnames=["client"],
+)
+EVENTS_OBSERVED = Counter(
+    "prefect_events_observed",
+    "The number of events observed by Prefect event subscribers",
+    labelnames=["client"],
+)
+EVENT_WEBSOCKET_CONNECTIONS = Counter(
+    "prefect_event_websocket_connections",
+    (
+        "The number of times Prefect event clients have connected to an event stream, "
+        "broken down by direction (in/out) and connection (initial/reconnect)"
+    ),
+    labelnames=["client", "direction", "connection"],
+)
+EVENT_WEBSOCKET_CHECKPOINTS = Counter(
+    "prefect_event_websocket_checkpoints",
+    "The number of checkpoints performed by Prefect event clients",
+    labelnames=["client"],
+)
+
 logger = get_logger(__name__)
 
 
 def get_events_client(
     reconnection_attempts: int = 10,
-    checkpoint_every: int =
+    checkpoint_every: int = 700,
 ) -> "EventsClient":
     api_url = PREFECT_API_URL.value()
     if isinstance(api_url, str) and api_url.startswith(PREFECT_CLOUD_API_URL.value()):
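A note on the new metrics above: `prometheus_client.Counter` registers each metric in the default registry at import time, and because `labelnames=[...]` is given, a labelled child must be fetched with `.labels(...)` before `.inc()` is called, which is how the clients below use EVENTS_EMITTED and friends. A minimal, self-contained sketch of that pattern (the `demo_*` names are illustrative, not metrics Prefect registers):

    from prometheus_client import Counter, generate_latest

    # Mirrors the labelnames=["client"] pattern used above.
    DEMO_EVENTS_EMITTED = Counter(
        "demo_events_emitted",
        "Number of events emitted, labelled by client class",
        labelnames=["client"],
    )

    # Each distinct label value becomes its own child time series.
    DEMO_EVENTS_EMITTED.labels("PrefectEventsClient").inc()
    DEMO_EVENTS_EMITTED.labels("PrefectCloudEventsClient").inc(2)

    # generate_latest() renders the default registry in the Prometheus text format.
    print(generate_latest().decode())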
@@ -49,13 +77,25 @@ def get_events_client(
             reconnection_attempts=reconnection_attempts,
             checkpoint_every=checkpoint_every,
         )
-    elif
+    elif api_url:
+        return PrefectEventsClient(
+            reconnection_attempts=reconnection_attempts,
+            checkpoint_every=checkpoint_every,
+        )
+    elif PREFECT_SERVER_ALLOW_EPHEMERAL_MODE:
+        from prefect.server.api.server import SubprocessASGIServer
+
+        server = SubprocessASGIServer()
+        server.start()
         return PrefectEventsClient(
+            api_url=server.api_url,
             reconnection_attempts=reconnection_attempts,
             checkpoint_every=checkpoint_every,
         )
     else:
-
+        raise ValueError(
+            "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server."
+        )
 
 
 def get_events_subscriber(
@@ -63,25 +103,38 @@ def get_events_subscriber(
     reconnection_attempts: int = 10,
 ) -> "PrefectEventSubscriber":
     api_url = PREFECT_API_URL.value()
-    if not api_url:
-        raise ValueError(
-            "A Prefect server or Prefect Cloud is required to start an event "
-            "subscriber. Please check the PREFECT_API_URL setting in your profile."
-        )
 
     if isinstance(api_url, str) and api_url.startswith(PREFECT_CLOUD_API_URL.value()):
         return PrefectCloudEventSubscriber(
             filter=filter, reconnection_attempts=reconnection_attempts
         )
-
+    elif api_url:
         return PrefectEventSubscriber(
             filter=filter, reconnection_attempts=reconnection_attempts
         )
+    elif PREFECT_SERVER_ALLOW_EPHEMERAL_MODE:
+        from prefect.server.api.server import SubprocessASGIServer
+
+        server = SubprocessASGIServer()
+        server.start()
+        return PrefectEventSubscriber(
+            api_url=server.api_url,
+            filter=filter,
+            reconnection_attempts=reconnection_attempts,
+        )
+    else:
+        raise ValueError(
+            "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server."
+        )
 
 
 class EventsClient(abc.ABC):
     """The abstract interface for all Prefect Events clients"""
 
+    @property
+    def client_name(self) -> str:
+        return self.__class__.__name__
+
     async def emit(self, event: Event) -> None:
         """Emit a single event"""
         if not hasattr(self, "_in_context"):
|
|
89
142
|
"Events may only be emitted while this client is being used as a "
|
90
143
|
"context manager"
|
91
144
|
)
|
92
|
-
|
145
|
+
|
146
|
+
try:
|
147
|
+
return await self._emit(event)
|
148
|
+
finally:
|
149
|
+
EVENTS_EMITTED.labels(self.client_name).inc()
|
93
150
|
|
94
151
|
@abc.abstractmethod
|
95
152
|
async def _emit(self, event: Event) -> None: # pragma: no cover
|
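The `emit()` guard above only allows events to be sent while the client is open as an async context manager (which is what sets `_in_context`), and the new try/finally bumps EVENTS_EMITTED for every attempt. A hedged sketch of the intended call pattern; the event and resource names are invented:

    import asyncio

    import pendulum

    from prefect.events import Event
    from prefect.events.clients import get_events_client


    async def emit_one() -> None:
        # get_events_client() resolves Cloud / server / ephemeral as shown above;
        # emit() must run inside the async context manager.
        async with get_events_client() as client:
            await client.emit(
                Event(
                    occurred=pendulum.now("UTC"),
                    event="external.system.refreshed",
                    resource={"prefect.resource.id": "external.system.my-db"},
                )
            )


    asyncio.run(emit_one())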
@@ -168,47 +225,6 @@ def _get_api_url_and_key(
     return api_url, api_key
 
 
-class PrefectEphemeralEventsClient(EventsClient):
-    """A Prefect Events client that sends events to an ephemeral Prefect server"""
-
-    def __init__(self):
-        if PREFECT_API_KEY.value():
-            raise ValueError(
-                "PrefectEphemeralEventsClient cannot be used when PREFECT_API_KEY is set."
-                " Please use PrefectEventsClient or PrefectCloudEventsClient instead."
-            )
-        from prefect.server.api.server import create_app
-
-        app = create_app(ephemeral=True)
-
-        self._http_client = PrefectHttpxAsyncClient(
-            transport=httpx.ASGITransport(app=app, raise_app_exceptions=False),
-            base_url="http://ephemeral-prefect/api",
-            enable_csrf_support=False,
-        )
-
-    async def __aenter__(self) -> Self:
-        await super().__aenter__()
-        await self._http_client.__aenter__()
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: Optional[Type[Exception]],
-        exc_val: Optional[Exception],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        self._websocket = None
-        await self._http_client.__aexit__(exc_type, exc_val, exc_tb)
-        return await super().__aexit__(exc_type, exc_val, exc_tb)
-
-    async def _emit(self, event: Event) -> None:
-        await self._http_client.post(
-            "/events",
-            json=[event.model_dump(mode="json")],
-        )
-
-
 class PrefectEventsClient(EventsClient):
     """A Prefect Events client that streams events to a Prefect server"""
 
@@ -219,7 +235,7 @@ class PrefectEventsClient(EventsClient):
         self,
         api_url: Optional[str] = None,
         reconnection_attempts: int = 10,
-        checkpoint_every: int =
+        checkpoint_every: int = 700,
     ):
         """
         Args:
@@ -299,6 +315,8 @@ class PrefectEventsClient(EventsClient):
         # don't clear the list, just the ones that we are sure of.
         self._unconfirmed_events = self._unconfirmed_events[unconfirmed_count:]
 
+        EVENT_WEBSOCKET_CHECKPOINTS.labels(self.client_name).inc()
+
     async def _emit(self, event: Event) -> None:
         for i in range(self._reconnection_attempts + 1):
             try:
@@ -336,7 +354,7 @@ class PrefectCloudEventsClient(PrefectEventsClient):
         api_url: Optional[str] = None,
         api_key: Optional[str] = None,
         reconnection_attempts: int = 10,
-        checkpoint_every: int =
+        checkpoint_every: int = 700,
     ):
         """
         Args:
@@ -400,9 +418,9 @@ class PrefectEventSubscriber:
             reconnection_attempts: When the client is disconnected, how many times
                 the client should attempt to reconnect
         """
+        self._api_key = None
         if not api_url:
             api_url = cast(str, PREFECT_API_URL.value())
-            self._api_key = None
 
         from prefect.events.filters import EventFilter
 
@@ -426,10 +444,17 @@ class PrefectEventSubscriber:
         if self._reconnection_attempts < 0:
             raise ValueError("reconnection_attempts must be a non-negative integer")
 
+    @property
+    def client_name(self) -> str:
+        return self.__class__.__name__
+
     async def __aenter__(self) -> Self:
         # Don't handle any errors in the initial connection, because these are most
         # likely a permission or configuration issue that should propagate
-
+        try:
+            await self._reconnect()
+        finally:
+            EVENT_WEBSOCKET_CONNECTIONS.labels(self.client_name, "out", "initial")
         return self
 
     async def _reconnect(self) -> None:
@@ -503,7 +528,12 @@ class PrefectEventSubscriber:
                 # Otherwise, after the first time through this loop, we're recovering
                 # from a ConnectionClosed, so reconnect now.
                 if not self._websocket or i > 0:
-
+                    try:
+                        await self._reconnect()
+                    finally:
+                        EVENT_WEBSOCKET_CONNECTIONS.labels(
+                            self.client_name, "out", "reconnect"
+                        )
                 assert self._websocket
 
                 while True:
@@ -514,7 +544,10 @@ class PrefectEventSubscriber:
                         continue
                     self._seen_events[event.id] = True
 
-
+                    try:
+                        return event
+                    finally:
+                        EVENTS_OBSERVED.labels(self.client_name).inc()
             except ConnectionClosedOK:
                 logger.debug('Connection closed with "OK" status')
                 raise StopAsyncIteration
prefect/events/utilities.py
CHANGED
@@ -8,7 +8,6 @@ from pydantic_extra_types.pendulum_dt import DateTime
 from .clients import (
     AssertingEventsClient,
     PrefectCloudEventsClient,
-    PrefectEphemeralEventsClient,
     PrefectEventsClient,
 )
 from .schemas.events import Event, RelatedResource
@@ -53,7 +52,6 @@ def emit_event(
         AssertingEventsClient,
         PrefectCloudEventsClient,
         PrefectEventsClient,
-        PrefectEphemeralEventsClient,
     ]
     worker_instance = EventsWorker.instance()
 
prefect/events/worker.py
CHANGED
@@ -17,7 +17,6 @@ from .clients import (
     EventsClient,
     NullEventsClient,
     PrefectCloudEventsClient,
-    PrefectEphemeralEventsClient,
     PrefectEventsClient,
 )
 from .related import related_resources_from_run_context
@@ -97,7 +96,15 @@ class EventsWorker(QueueService[Event]):
         elif should_emit_events_to_running_server():
             client_type = PrefectEventsClient
         elif should_emit_events_to_ephemeral_server():
-
+            # create an ephemeral API if none was provided
+            from prefect.server.api.server import SubprocessASGIServer
+
+            server = SubprocessASGIServer()
+            server.start()
+            assert server.server_process is not None, "Server process did not start"
+
+            client_kwargs = {"api_url": server.api_url}
+            client_type = PrefectEventsClient
         else:
             client_type = NullEventsClient
 
prefect/flow_engine.py
CHANGED
@@ -29,7 +29,8 @@ from prefect.client.orchestration import SyncPrefectClient, get_client
 from prefect.client.schemas import FlowRun, TaskRun
 from prefect.client.schemas.filters import FlowRunFilter
 from prefect.client.schemas.sorting import FlowRunSort
-from prefect.context import
+from prefect.concurrency.context import ConcurrencyContext
+from prefect.context import FlowRunContext, SyncClientContext, TagsContext
 from prefect.exceptions import (
     Abort,
     Pause,
@@ -505,6 +506,8 @@ class FlowRunEngine(Generic[P, R]):
                     task_runner=task_runner,
                 )
             )
+            stack.enter_context(ConcurrencyContext())
+
             # set the logger to the flow run logger
             self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow)
 
@@ -529,8 +532,8 @@ class FlowRunEngine(Generic[P, R]):
         """
         Enters a client context and creates a flow run if needed.
         """
-        with
-            self._client = client_ctx.
+        with SyncClientContext.get_or_create() as client_ctx:
+            self._client = client_ctx.client
             self._is_started = True
 
             if not self.flow_run:
prefect/flows.py
CHANGED
@@ -66,6 +66,7 @@ from prefect.exceptions import (
     ObjectNotFound,
     ParameterTypeError,
     ScriptError,
+    TerminationSignal,
     UnspecifiedFlowError,
 )
 from prefect.filesystems import LocalFileSystem, ReadableDeploymentStorage
@@ -95,7 +96,7 @@ from prefect.utilities.callables import (
     parameters_to_args_kwargs,
     raise_for_reserved_arguments,
 )
-from prefect.utilities.collections import listrepr
+from prefect.utilities.collections import listrepr, visit_collection
 from prefect.utilities.filesystem import relative_path_to_current_platform
 from prefect.utilities.hashing import file_hash
 from prefect.utilities.importtools import import_object, safe_load_namespace
@@ -535,6 +536,21 @@ class Flow(Generic[P, R]):
         Raises:
             ParameterTypeError: if the provided parameters are not valid
         """
+
+        def resolve_block_reference(data: Any) -> Any:
+            if isinstance(data, dict) and "$ref" in data:
+                return Block.load_from_ref(data["$ref"])
+            return data
+
+        try:
+            parameters = visit_collection(
+                parameters, resolve_block_reference, return_data=True
+            )
+        except (ValueError, RuntimeError) as exc:
+            raise ParameterTypeError(
+                "Failed to resolve block references in parameters."
+            ) from exc
+
         args, kwargs = parameters_to_args_kwargs(self.fn, parameters)
 
         with warnings.catch_warnings():
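The new block above makes Flow.validate_parameters walk the parameter collection with visit_collection and swap any {"$ref": ...} mapping for the referenced block before binding arguments. A small standalone illustration of that traversal pattern, using a stand-in resolver instead of Block.load_from_ref (the $ref payloads here are placeholders):

    from prefect.utilities.collections import visit_collection


    def resolve_ref(data):
        # Stand-in for Block.load_from_ref: swap {"$ref": ...} mappings for
        # whatever the reference points at.
        if isinstance(data, dict) and "$ref" in data:
            return f"<resolved {data['$ref']!r}>"
        return data


    parameters = {
        "plain": 1,
        "nested": {"credentials": {"$ref": "block-document-ref"}},
        "listed": [{"$ref": "another-ref"}, 2],
    }

    # return_data=True returns the transformed structure rather than only
    # visiting it for side effects.
    print(visit_collection(parameters, resolve_ref, return_data=True))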
@@ -931,10 +947,15 @@ class Flow(Generic[P, R]):
             else:
                 raise
 
-
-            loop
-
-
+        try:
+            if loop is not None:
+                loop.run_until_complete(runner.start(webserver=webserver))
+            else:
+                asyncio.run(runner.start(webserver=webserver))
+        except (KeyboardInterrupt, TerminationSignal) as exc:
+            logger.info(f"Received {type(exc).__name__}, shutting down...")
+            if loop is not None:
+                loop.stop()
 
     @classmethod
     @sync_compatible
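The serve() change above means Ctrl-C (or a TerminationSignal) while Flow.serve is running is logged and the runner shuts down instead of surfacing a traceback. A hedged usage sketch; the flow and deployment names are made up:

    from prefect import flow


    @flow(log_prints=True)
    def heartbeat():
        print("still alive")


    if __name__ == "__main__":
        # Blocks while polling for scheduled runs; interrupting it now logs
        # "Received KeyboardInterrupt, shutting down..." and exits cleanly.
        heartbeat.serve(name="heartbeat-deployment", interval=60)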
@@ -1734,14 +1755,13 @@ def load_flow_from_entrypoint(
         raise MissingFlowError(
             f"Flow function with name {func_name!r} not found in {path!r}. "
         ) from exc
-    except ScriptError
+    except ScriptError:
         # If the flow has dependencies that are not installed in the current
-        # environment, fallback to loading the flow via AST parsing.
-        # drawback of this approach is that we're unable to actually load the
-        # function, so we create a placeholder flow that will re-raise this
-        # exception when called.
+        # environment, fallback to loading the flow via AST parsing.
         if use_placeholder_flow:
-            flow =
+            flow = safe_load_flow_from_entrypoint(entrypoint)
+            if flow is None:
+                raise
         else:
             raise
 
@@ -1976,6 +1996,147 @@ def load_placeholder_flow(entrypoint: str, raises: Exception):
     return Flow(**arguments)
 
 
+def safe_load_flow_from_entrypoint(entrypoint: str) -> Optional[Flow]:
+    """
+    Load a flow from an entrypoint and return None if an exception is raised.
+
+    Args:
+        entrypoint: a string in the format `<path_to_script>:<flow_func_name>`
+          or a module path to a flow function
+    """
+    func_def, source_code = _entrypoint_definition_and_source(entrypoint)
+    path = None
+    if ":" in entrypoint:
+        path = entrypoint.rsplit(":")[0]
+    namespace = safe_load_namespace(source_code, filepath=path)
+    if func_def.name in namespace:
+        return namespace[func_def.name]
+    else:
+        # If the function is not in the namespace, if may be due to missing dependencies
+        # for the function. We will attempt to compile each annotation and default value
+        # and remove them from the function definition to see if the function can be
+        # compiled without them.
+
+        return _sanitize_and_load_flow(func_def, namespace)
+
+
+def _sanitize_and_load_flow(
+    func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef], namespace: Dict[str, Any]
+) -> Optional[Flow]:
+    """
+    Attempt to load a flow from the function definition after sanitizing the annotations
+    and defaults that can't be compiled.
+
+    Args:
+        func_def: the function definition
+        namespace: the namespace to load the function into
+
+    Returns:
+        The loaded function or None if the function can't be loaded
+        after sanitizing the annotations and defaults.
+    """
+    args = func_def.args.posonlyargs + func_def.args.args + func_def.args.kwonlyargs
+    if func_def.args.vararg:
+        args.append(func_def.args.vararg)
+    if func_def.args.kwarg:
+        args.append(func_def.args.kwarg)
+    # Remove annotations that can't be compiled
+    for arg in args:
+        if arg.annotation is not None:
+            try:
+                code = compile(
+                    ast.Expression(arg.annotation),
+                    filename="<ast>",
+                    mode="eval",
+                )
+                exec(code, namespace)
+            except Exception as e:
+                logger.debug(
+                    "Failed to evaluate annotation for argument %s due to the following error. Ignoring annotation.",
+                    arg.arg,
+                    exc_info=e,
+                )
+                arg.annotation = None
+
+    # Remove defaults that can't be compiled
+    new_defaults = []
+    for default in func_def.args.defaults:
+        try:
+            code = compile(ast.Expression(default), "<ast>", "eval")
+            exec(code, namespace)
+            new_defaults.append(default)
+        except Exception as e:
+            logger.debug(
+                "Failed to evaluate default value %s due to the following error. Ignoring default.",
+                default,
+                exc_info=e,
+            )
+            new_defaults.append(
+                ast.Constant(
+                    value=None, lineno=default.lineno, col_offset=default.col_offset
+                )
+            )
+    func_def.args.defaults = new_defaults
+
+    # Remove kw_defaults that can't be compiled
+    new_kw_defaults = []
+    for default in func_def.args.kw_defaults:
+        if default is not None:
+            try:
+                code = compile(ast.Expression(default), "<ast>", "eval")
+                exec(code, namespace)
+                new_kw_defaults.append(default)
+            except Exception as e:
+                logger.debug(
+                    "Failed to evaluate default value %s due to the following error. Ignoring default.",
+                    default,
+                    exc_info=e,
+                )
+                new_kw_defaults.append(
+                    ast.Constant(
+                        value=None,
+                        lineno=default.lineno,
+                        col_offset=default.col_offset,
+                    )
+                )
+        else:
+            new_kw_defaults.append(
+                ast.Constant(
+                    value=None,
+                    lineno=func_def.lineno,
+                    col_offset=func_def.col_offset,
+                )
+            )
+    func_def.args.kw_defaults = new_kw_defaults
+
+    if func_def.returns is not None:
+        try:
+            code = compile(
+                ast.Expression(func_def.returns), filename="<ast>", mode="eval"
+            )
+            exec(code, namespace)
+        except Exception as e:
+            logger.debug(
+                "Failed to evaluate return annotation due to the following error. Ignoring annotation.",
+                exc_info=e,
+            )
+            func_def.returns = None
+
+    # Attempt to compile the function without annotations and defaults that
+    # can't be compiled
+    try:
+        code = compile(
+            ast.Module(body=[func_def], type_ignores=[]),
+            filename="<ast>",
+            mode="exec",
+        )
+        exec(code, namespace)
+    except Exception as e:
+        logger.debug("Failed to compile: %s", e)
+    else:
+        return namespace.get(func_def.name)
+
+
 def load_flow_arguments_from_entrypoint(
     entrypoint: str, arguments: Optional[Union[List[str], Set[str]]] = None
 ) -> dict[str, Any]:
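The core trick in _sanitize_and_load_flow above is to probe each annotation and default by compiling it as an ast.Expression and exec-ing the result in the partially loaded namespace; anything that raises (for example, a name from an uninstalled package) is dropped or replaced with None before the whole function definition is compiled. A standalone sketch of that probe using only the standard library (the `missing_pkg` module is deliberately nonexistent):

    import ast

    source = "def my_flow(x: missing_pkg.Widget = missing_pkg.DEFAULT, y: int = 3): ..."
    func_def = ast.parse(source).body[0]
    namespace: dict = {}

    for arg, default in zip(func_def.args.args, func_def.args.defaults):
        for label, node in (("annotation", arg.annotation), ("default", default)):
            try:
                # Same probe as above: compile the expression node and evaluate it
                # to see whether every name in it actually resolves.
                exec(compile(ast.Expression(node), "<ast>", "eval"), namespace)
                print(f"{arg.arg} {label}: ok")
            except Exception as exc:
                print(f"{arg.arg} {label}: dropped ({exc})")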
@@ -1991,6 +2152,9 @@ def load_flow_arguments_from_entrypoint(
     """
 
     func_def, source_code = _entrypoint_definition_and_source(entrypoint)
+    path = None
+    if ":" in entrypoint:
+        path = entrypoint.rsplit(":")[0]
 
     if arguments is None:
         # If no arguments are provided default to known arguments that are of
@@ -2026,7 +2190,7 @@
 
             # if the arg value is not a raw str (i.e. a variable or expression),
             # then attempt to evaluate it
-            namespace = safe_load_namespace(source_code)
+            namespace = safe_load_namespace(source_code, filepath=path)
             literal_arg_value = ast.get_source_segment(source_code, keyword.value)
             cleaned_value = (
                 literal_arg_value.replace("\n", "") if literal_arg_value else ""