flyte 0.2.0b5__py3-none-any.whl → 0.2.0b8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +2 -1
- flyte/_code_bundle/_utils.py +0 -16
- flyte/_code_bundle/bundle.py +1 -1
- flyte/_environment.py +42 -1
- flyte/_image.py +1 -2
- flyte/_initialize.py +52 -23
- flyte/_internal/controllers/__init__.py +2 -0
- flyte/_internal/controllers/_local_controller.py +3 -0
- flyte/_internal/controllers/remote/_controller.py +3 -0
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/controllers/remote/_informer.py +3 -3
- flyte/_task.py +51 -12
- flyte/_task_environment.py +48 -66
- flyte/_utils/coro_management.py +0 -2
- flyte/_version.py +2 -2
- flyte/cli/_common.py +24 -15
- flyte/cli/_create.py +39 -8
- flyte/cli/_delete.py +2 -2
- flyte/cli/_deploy.py +4 -1
- flyte/cli/_gen.py +155 -0
- flyte/cli/_get.py +53 -7
- flyte/cli/_run.py +34 -8
- flyte/cli/main.py +69 -16
- flyte/config/__init__.py +2 -189
- flyte/config/_config.py +181 -172
- flyte/config/_internal.py +1 -1
- flyte/config/_reader.py +207 -0
- flyte/extras/_container.py +1 -1
- flyte/remote/_logs.py +9 -2
- flyte/remote/_run.py +26 -17
- flyte/syncify/__init__.py +51 -0
- flyte/syncify/_api.py +5 -6
- flyte-0.2.0b8.dist-info/METADATA +180 -0
- {flyte-0.2.0b5.dist-info → flyte-0.2.0b8.dist-info}/RECORD +37 -35
- flyte-0.2.0b5.dist-info/METADATA +0 -178
- {flyte-0.2.0b5.dist-info → flyte-0.2.0b8.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b5.dist-info → flyte-0.2.0b8.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b5.dist-info → flyte-0.2.0b8.dist-info}/top_level.txt +0 -0
flyte/__init__.py
CHANGED
|
@@ -38,6 +38,7 @@ __all__ = [
|
|
|
38
38
|
"deploy",
|
|
39
39
|
"group",
|
|
40
40
|
"init",
|
|
41
|
+
"init_auto_from_config",
|
|
41
42
|
"run",
|
|
42
43
|
"trace",
|
|
43
44
|
"with_runcontext",
|
|
@@ -49,7 +50,7 @@ from ._deploy import deploy
|
|
|
49
50
|
from ._environment import Environment
|
|
50
51
|
from ._group import group
|
|
51
52
|
from ._image import Image
|
|
52
|
-
from ._initialize import init
|
|
53
|
+
from ._initialize import init, init_auto_from_config
|
|
53
54
|
from ._resources import GPU, TPU, Device, Resources
|
|
54
55
|
from ._retry import RetryStrategy
|
|
55
56
|
from ._reusable_environment import ReusePolicy
|
flyte/_code_bundle/_utils.py
CHANGED
|
@@ -14,7 +14,6 @@ import tempfile
|
|
|
14
14
|
import typing
|
|
15
15
|
from datetime import datetime, timezone
|
|
16
16
|
from functools import lru_cache
|
|
17
|
-
from pathlib import Path
|
|
18
17
|
from types import ModuleType
|
|
19
18
|
from typing import List, Literal, Optional, Tuple, Union
|
|
20
19
|
|
|
@@ -322,18 +321,3 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> Tuple[bytes, str, in
|
|
|
322
321
|
size += len(chunk)
|
|
323
322
|
|
|
324
323
|
return h.digest(), h.hexdigest(), size
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
def _find_project_root(source_path) -> str:
|
|
328
|
-
"""
|
|
329
|
-
Find the root of the project.
|
|
330
|
-
The root of the project is considered to be the first ancestor from source_path that does
|
|
331
|
-
not contain a __init__.py file.
|
|
332
|
-
|
|
333
|
-
N.B.: This assumption only holds for regular packages (as opposed to namespace packages)
|
|
334
|
-
"""
|
|
335
|
-
# Start from the directory right above source_path
|
|
336
|
-
path = Path(source_path).parent.resolve()
|
|
337
|
-
while os.path.exists(os.path.join(path, "__init__.py")):
|
|
338
|
-
path = path.parent
|
|
339
|
-
return str(path)
|
flyte/_code_bundle/bundle.py
CHANGED
|
@@ -146,7 +146,7 @@ async def build_code_bundle(
|
|
|
146
146
|
logger.info(f"Code bundle created at {bundle_path}, size: {tar_size} MB, archive size: {archive_size} MB")
|
|
147
147
|
if not dryrun:
|
|
148
148
|
hash_digest, remote_path = await upload_file(bundle_path)
|
|
149
|
-
logger.
|
|
149
|
+
logger.debug(f"Code bundle uploaded to {remote_path}")
|
|
150
150
|
else:
|
|
151
151
|
remote_path = "na"
|
|
152
152
|
if copy_bundle_to:
|
flyte/_environment.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import re
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
|
-
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
|
5
6
|
|
|
6
7
|
import rich.repr
|
|
7
8
|
|
|
@@ -14,6 +15,10 @@ if TYPE_CHECKING:
|
|
|
14
15
|
from kubernetes.client import V1PodTemplate
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
def is_snake_or_kebab_with_numbers(s: str) -> bool:
|
|
19
|
+
return re.fullmatch(r"^[a-z0-9]+([_-][a-z0-9]+)*$", s) is not None
|
|
20
|
+
|
|
21
|
+
|
|
17
22
|
@rich.repr.auto
|
|
18
23
|
@dataclass(init=True, repr=True)
|
|
19
24
|
class Environment:
|
|
@@ -36,8 +41,44 @@ class Environment:
|
|
|
36
41
|
resources: Optional[Resources] = None
|
|
37
42
|
image: Union[str, Image, Literal["auto"]] = "auto"
|
|
38
43
|
|
|
44
|
+
def __post_init__(self):
|
|
45
|
+
if not is_snake_or_kebab_with_numbers(self.name):
|
|
46
|
+
raise ValueError(f"Environment name '{self.name}' must be in snake_case or kebab-case format.")
|
|
47
|
+
|
|
39
48
|
def add_dependency(self, *env: Environment):
|
|
40
49
|
"""
|
|
41
50
|
Add a dependency to the environment.
|
|
42
51
|
"""
|
|
43
52
|
self.env_dep_hints.extend(env)
|
|
53
|
+
|
|
54
|
+
def clone_with(
|
|
55
|
+
self,
|
|
56
|
+
name: str,
|
|
57
|
+
image: Optional[Union[str, Image, Literal["auto"]]] = None,
|
|
58
|
+
resources: Optional[Resources] = None,
|
|
59
|
+
env: Optional[Dict[str, str]] = None,
|
|
60
|
+
secrets: Optional[SecretRequest] = None,
|
|
61
|
+
env_dep_hints: Optional[List[Environment]] = None,
|
|
62
|
+
**kwargs: Any,
|
|
63
|
+
) -> Environment:
|
|
64
|
+
raise NotImplementedError
|
|
65
|
+
|
|
66
|
+
def _get_kwargs(self) -> Dict[str, Any]:
|
|
67
|
+
"""
|
|
68
|
+
Get the keyword arguments for the environment.
|
|
69
|
+
"""
|
|
70
|
+
kwargs: Dict[str, Any] = {
|
|
71
|
+
"env_dep_hints": self.env_dep_hints,
|
|
72
|
+
"image": self.image,
|
|
73
|
+
}
|
|
74
|
+
if self.resources is not None:
|
|
75
|
+
kwargs["resources"] = self.resources
|
|
76
|
+
if self.secrets is not None:
|
|
77
|
+
kwargs["secrets"] = self.secrets
|
|
78
|
+
if self.env is not None:
|
|
79
|
+
kwargs["env"] = self.env
|
|
80
|
+
if self.pod_template is not None:
|
|
81
|
+
kwargs["pod_template"] = self.pod_template
|
|
82
|
+
if self.description is not None:
|
|
83
|
+
kwargs["description"] = self.description
|
|
84
|
+
return kwargs
|
flyte/_image.py
CHANGED
|
@@ -444,8 +444,7 @@ class Image:
|
|
|
444
444
|
```
|
|
445
445
|
|
|
446
446
|
For more information on the uv script format, see the documentation:
|
|
447
|
-
|
|
448
|
-
UV: Declaring script dependencies</href>
|
|
447
|
+
[UV: Declaring script dependencies](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies)
|
|
449
448
|
|
|
450
449
|
:param name: name of the image
|
|
451
450
|
:param registry: registry to use for the image
|
flyte/_initialize.py
CHANGED
|
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Callable, List, Literal, Optional, TypeVar
|
|
|
10
10
|
from flyte.errors import InitializationError
|
|
11
11
|
from flyte.syncify import syncify
|
|
12
12
|
|
|
13
|
-
from ._logging import initialize_logger
|
|
13
|
+
from ._logging import initialize_logger, logger
|
|
14
14
|
from ._tools import ipython_check
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
@@ -130,7 +130,6 @@ async def init(
|
|
|
130
130
|
rpc_retries: int = 3,
|
|
131
131
|
http_proxy_url: str | None = None,
|
|
132
132
|
storage: Storage | None = None,
|
|
133
|
-
config: Config | None = None,
|
|
134
133
|
) -> None:
|
|
135
134
|
"""
|
|
136
135
|
Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
|
|
@@ -179,42 +178,72 @@ async def init(
|
|
|
179
178
|
global _init_config # noqa: PLW0603
|
|
180
179
|
|
|
181
180
|
with _init_lock:
|
|
182
|
-
if config is None:
|
|
183
|
-
import flyte.config as _f_cfg
|
|
184
|
-
|
|
185
|
-
config = _f_cfg.Config()
|
|
186
|
-
platform_cfg = config.platform
|
|
187
|
-
task_cfg = config.task
|
|
188
181
|
client = None
|
|
189
|
-
if endpoint or
|
|
182
|
+
if endpoint or api_key:
|
|
190
183
|
client = await _initialize_client(
|
|
191
184
|
api_key=api_key,
|
|
192
|
-
auth_type=auth_type
|
|
193
|
-
endpoint=endpoint
|
|
185
|
+
auth_type=auth_type,
|
|
186
|
+
endpoint=endpoint,
|
|
194
187
|
headless=headless,
|
|
195
|
-
insecure=insecure
|
|
196
|
-
insecure_skip_verify=insecure_skip_verify
|
|
197
|
-
ca_cert_file_path=ca_cert_file_path
|
|
198
|
-
command=command
|
|
199
|
-
proxy_command=proxy_command
|
|
200
|
-
client_id=client_id
|
|
201
|
-
client_credentials_secret=client_credentials_secret
|
|
188
|
+
insecure=insecure,
|
|
189
|
+
insecure_skip_verify=insecure_skip_verify,
|
|
190
|
+
ca_cert_file_path=ca_cert_file_path,
|
|
191
|
+
command=command,
|
|
192
|
+
proxy_command=proxy_command,
|
|
193
|
+
client_id=client_id,
|
|
194
|
+
client_credentials_secret=client_credentials_secret,
|
|
202
195
|
client_config=auth_client_config,
|
|
203
|
-
rpc_retries=rpc_retries
|
|
204
|
-
http_proxy_url=http_proxy_url
|
|
196
|
+
rpc_retries=rpc_retries,
|
|
197
|
+
http_proxy_url=http_proxy_url,
|
|
205
198
|
)
|
|
206
199
|
|
|
207
200
|
root_dir = root_dir or get_cwd_editable_install() or Path.cwd()
|
|
208
201
|
_init_config = _InitConfig(
|
|
209
202
|
root_dir=root_dir,
|
|
210
|
-
project=project
|
|
211
|
-
domain=domain
|
|
203
|
+
project=project,
|
|
204
|
+
domain=domain,
|
|
212
205
|
client=client,
|
|
213
206
|
storage=storage,
|
|
214
|
-
org=org
|
|
207
|
+
org=org,
|
|
215
208
|
)
|
|
216
209
|
|
|
217
210
|
|
|
211
|
+
@syncify
|
|
212
|
+
async def init_auto_from_config(path_or_config: str | Config | None = None) -> None:
|
|
213
|
+
"""
|
|
214
|
+
Initialize the Flyte system using a configuration file or Config object. This method should be called before any
|
|
215
|
+
other Flyte remote API methods are called. Thread-safe implementation.
|
|
216
|
+
|
|
217
|
+
:param path_or_config: Path to the configuration file or Config object
|
|
218
|
+
:return: None
|
|
219
|
+
"""
|
|
220
|
+
import flyte.config as config
|
|
221
|
+
|
|
222
|
+
cfg: config.Config
|
|
223
|
+
if path_or_config is None or isinstance(path_or_config, str):
|
|
224
|
+
# If a string is passed, treat it as a path to the config file
|
|
225
|
+
cfg = config.auto(path_or_config)
|
|
226
|
+
else:
|
|
227
|
+
# If a Config object is passed, use it directly
|
|
228
|
+
cfg = path_or_config
|
|
229
|
+
|
|
230
|
+
logger.debug(f"Flyte config initialized as {cfg}")
|
|
231
|
+
await init.aio(
|
|
232
|
+
org=cfg.task.org,
|
|
233
|
+
project=cfg.task.project,
|
|
234
|
+
domain=cfg.task.domain,
|
|
235
|
+
endpoint=cfg.platform.endpoint,
|
|
236
|
+
insecure=cfg.platform.insecure,
|
|
237
|
+
insecure_skip_verify=cfg.platform.insecure_skip_verify,
|
|
238
|
+
ca_cert_file_path=cfg.platform.ca_cert_file_path,
|
|
239
|
+
auth_type=cfg.platform.auth_mode,
|
|
240
|
+
command=cfg.platform.command,
|
|
241
|
+
proxy_command=cfg.platform.proxy_command,
|
|
242
|
+
client_id=cfg.platform.client_id,
|
|
243
|
+
client_credentials_secret=cfg.platform.client_credentials_secret,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
|
|
218
247
|
def _get_init_config() -> Optional[_InitConfig]:
|
|
219
248
|
"""
|
|
220
249
|
Get the current initialization configuration. Thread-safe implementation.
|
|
@@ -28,6 +28,8 @@ class Controller(Protocol):
|
|
|
28
28
|
"""
|
|
29
29
|
...
|
|
30
30
|
|
|
31
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> Any: ...
|
|
32
|
+
|
|
31
33
|
async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
|
|
32
34
|
"""
|
|
33
35
|
Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
|
|
@@ -8,6 +8,7 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
|
8
8
|
from flyte._logging import log, logger
|
|
9
9
|
from flyte._protos.workflow import task_definition_pb2
|
|
10
10
|
from flyte._task import TaskTemplate
|
|
11
|
+
from flyte._utils.asyn import loop_manager
|
|
11
12
|
from flyte.models import ActionID, NativeInterface, RawDataPath
|
|
12
13
|
|
|
13
14
|
R = TypeVar("R")
|
|
@@ -58,6 +59,8 @@ class LocalController:
|
|
|
58
59
|
return result
|
|
59
60
|
return out
|
|
60
61
|
|
|
62
|
+
submit_sync = loop_manager.synced(submit)
|
|
63
|
+
|
|
61
64
|
async def finalize_parent_action(self, action: ActionID):
|
|
62
65
|
pass
|
|
63
66
|
|
|
@@ -21,6 +21,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
|
21
21
|
from flyte._logging import logger
|
|
22
22
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
23
|
from flyte._task import TaskTemplate
|
|
24
|
+
from flyte._utils.asyn import loop_manager
|
|
24
25
|
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
25
26
|
|
|
26
27
|
R = TypeVar("R")
|
|
@@ -235,6 +236,8 @@ class RemoteController(Controller):
|
|
|
235
236
|
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
|
|
236
237
|
return await self._submit(task_call_seq, _task, *args, **kwargs)
|
|
237
238
|
|
|
239
|
+
submit_sync = loop_manager.synced(submit)
|
|
240
|
+
|
|
238
241
|
async def finalize_parent_action(self, action_id: ActionID):
|
|
239
242
|
"""
|
|
240
243
|
This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
|
|
@@ -32,7 +32,7 @@ class Controller:
|
|
|
32
32
|
max_system_retries: int = 5,
|
|
33
33
|
resource_log_interval_sec: float = 10.0,
|
|
34
34
|
min_backoff_on_err_sec: float = 0.1,
|
|
35
|
-
thread_wait_timeout_sec: float = 0
|
|
35
|
+
thread_wait_timeout_sec: float = 5.0,
|
|
36
36
|
enqueue_timeout_sec: float = 5.0,
|
|
37
37
|
):
|
|
38
38
|
"""
|
|
@@ -38,11 +38,11 @@ class ActionCache:
|
|
|
38
38
|
"""
|
|
39
39
|
Add an action to the cache if it doesn't exist. This is invoked by the watch.
|
|
40
40
|
"""
|
|
41
|
-
logger.
|
|
41
|
+
logger.debug(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
|
|
42
42
|
if state.output_uri:
|
|
43
|
-
logger.
|
|
43
|
+
logger.debug(f"Output URI: {state.output_uri}")
|
|
44
44
|
else:
|
|
45
|
-
logger.
|
|
45
|
+
logger.warning(f"{state.action_id.name} has no output URI")
|
|
46
46
|
if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
|
|
47
47
|
logger.error(
|
|
48
48
|
f"Action {state.action_id.name} failed with error (msg):"
|
flyte/_task.py
CHANGED
|
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import weakref
|
|
4
4
|
from dataclasses import dataclass, field, replace
|
|
5
5
|
from functools import cached_property
|
|
6
|
+
from inspect import iscoroutinefunction
|
|
6
7
|
from typing import (
|
|
7
8
|
TYPE_CHECKING,
|
|
8
9
|
Any,
|
|
9
|
-
Awaitable,
|
|
10
10
|
Callable,
|
|
11
11
|
Coroutine,
|
|
12
12
|
Dict,
|
|
@@ -15,6 +15,7 @@ from typing import (
|
|
|
15
15
|
Literal,
|
|
16
16
|
Optional,
|
|
17
17
|
ParamSpec,
|
|
18
|
+
TypeAlias,
|
|
18
19
|
TypeVar,
|
|
19
20
|
Union,
|
|
20
21
|
)
|
|
@@ -42,6 +43,10 @@ if TYPE_CHECKING:
|
|
|
42
43
|
P = ParamSpec("P") # capture the function's parameters
|
|
43
44
|
R = TypeVar("R") # return type
|
|
44
45
|
|
|
46
|
+
AsyncFunctionType: TypeAlias = Callable[P, Coroutine[Any, Any, R]]
|
|
47
|
+
SyncFunctionType: TypeAlias = Callable[P, R]
|
|
48
|
+
FunctionTypes: TypeAlias = Union[AsyncFunctionType, SyncFunctionType]
|
|
49
|
+
|
|
45
50
|
|
|
46
51
|
@dataclass(kw_only=True)
|
|
47
52
|
class TaskTemplate(Generic[P, R]):
|
|
@@ -98,6 +103,9 @@ class TaskTemplate(Generic[P, R]):
|
|
|
98
103
|
local: bool = field(default=False, init=False)
|
|
99
104
|
ref: bool = field(default=False, init=False, repr=False, compare=False)
|
|
100
105
|
|
|
106
|
+
# Only used in python 3.10 and 3.11, where we cannot use markcoroutinefunction
|
|
107
|
+
_call_as_synchronous: bool = False
|
|
108
|
+
|
|
101
109
|
def __post_init__(self):
|
|
102
110
|
# If pod_template is set to a pod, verify
|
|
103
111
|
if self.pod_template is not None and not isinstance(self.pod_template, str):
|
|
@@ -208,13 +216,15 @@ class TaskTemplate(Generic[P, R]):
|
|
|
208
216
|
def native_interface(self) -> NativeInterface:
|
|
209
217
|
return self.interface
|
|
210
218
|
|
|
211
|
-
|
|
219
|
+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
|
|
212
220
|
"""
|
|
213
221
|
This is the entrypoint for an async function task at runtime. It will be called during an execution.
|
|
214
222
|
Please do not override this method, if you simply want to modify the execution behavior, override the
|
|
215
223
|
execute method.
|
|
216
224
|
|
|
217
|
-
|
|
225
|
+
This needs to be overridable to maybe be async.
|
|
226
|
+
The returned thing from here needs to be an awaitable if the underlying task is async, and a regular object
|
|
227
|
+
if the task is not.
|
|
218
228
|
"""
|
|
219
229
|
try:
|
|
220
230
|
ctx = internal_ctx()
|
|
@@ -226,8 +236,14 @@ class TaskTemplate(Generic[P, R]):
|
|
|
226
236
|
|
|
227
237
|
controller = get_controller()
|
|
228
238
|
if controller:
|
|
229
|
-
|
|
230
|
-
|
|
239
|
+
if self._call_as_synchronous:
|
|
240
|
+
return controller.submit_sync(self, *args, **kwargs)
|
|
241
|
+
else:
|
|
242
|
+
return controller.submit(self, *args, **kwargs)
|
|
243
|
+
else:
|
|
244
|
+
raise RuntimeSystemError("BadContext", "Controller is not initialized.")
|
|
245
|
+
|
|
246
|
+
return self.forward(*args, **kwargs)
|
|
231
247
|
except RuntimeSystemError:
|
|
232
248
|
raise
|
|
233
249
|
except RuntimeUserError:
|
|
@@ -235,6 +251,17 @@ class TaskTemplate(Generic[P, R]):
|
|
|
235
251
|
except Exception as e:
|
|
236
252
|
raise RuntimeUserError(type(e).__name__, str(e)) from e
|
|
237
253
|
|
|
254
|
+
def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
|
|
255
|
+
"""
|
|
256
|
+
Think of this as a local execute method for your task. This function will be invoked by the __call__ method
|
|
257
|
+
when not in a Flyte task execution context. See the implementation below for an example.
|
|
258
|
+
|
|
259
|
+
:param args:
|
|
260
|
+
:param kwargs:
|
|
261
|
+
:return:
|
|
262
|
+
"""
|
|
263
|
+
raise NotImplementedError
|
|
264
|
+
|
|
238
265
|
def override(
|
|
239
266
|
self,
|
|
240
267
|
*,
|
|
@@ -290,26 +317,38 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
|
|
|
290
317
|
is decorated with the task decorator.
|
|
291
318
|
"""
|
|
292
319
|
|
|
293
|
-
func:
|
|
320
|
+
func: FunctionTypes
|
|
321
|
+
|
|
322
|
+
def __post_init__(self):
|
|
323
|
+
super().__post_init__()
|
|
324
|
+
if not iscoroutinefunction(self.func):
|
|
325
|
+
self._call_as_synchronous = True
|
|
294
326
|
|
|
295
327
|
@cached_property
|
|
296
328
|
def native_interface(self) -> NativeInterface:
|
|
297
329
|
return NativeInterface.from_callable(self.func)
|
|
298
330
|
|
|
331
|
+
def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
|
|
332
|
+
# In local execution, we want to just call the function. Note we're not awaiting anything here.
|
|
333
|
+
# If the function was a coroutine function, the coroutine is returned and the await that the caller has
|
|
334
|
+
# in front of the task invocation will handle the awaiting.
|
|
335
|
+
return self.func(*args, **kwargs)
|
|
336
|
+
|
|
299
337
|
async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
|
300
338
|
"""
|
|
301
339
|
This is the execute method that will be called when the task is invoked. It will call the actual function.
|
|
302
340
|
# TODO We may need to keep this as the bare func execute, and need a pre and post execute some other func.
|
|
303
341
|
"""
|
|
342
|
+
|
|
304
343
|
ctx = internal_ctx()
|
|
344
|
+
assert ctx.data.task_context is not None, "Function should have already returned if not in a task context"
|
|
305
345
|
ctx_data = await self.pre(*args, **kwargs)
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
346
|
+
tctx = ctx.data.task_context.replace(data=ctx_data)
|
|
347
|
+
with ctx.replace_task_context(tctx):
|
|
348
|
+
if iscoroutinefunction(self.func):
|
|
309
349
|
v = await self.func(*args, **kwargs)
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
v = await self.func(*args, **kwargs)
|
|
350
|
+
else:
|
|
351
|
+
v = self.func(*args, **kwargs)
|
|
313
352
|
await self.post(v)
|
|
314
353
|
return v
|
|
315
354
|
|
flyte/_task_environment.py
CHANGED
|
@@ -1,11 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
import weakref
|
|
5
4
|
from dataclasses import dataclass, field, replace
|
|
6
5
|
from datetime import timedelta
|
|
7
|
-
from
|
|
8
|
-
|
|
6
|
+
from typing import (
|
|
7
|
+
TYPE_CHECKING,
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
Dict,
|
|
11
|
+
List,
|
|
12
|
+
Literal,
|
|
13
|
+
Optional,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
)
|
|
9
17
|
|
|
10
18
|
import rich.repr
|
|
11
19
|
|
|
@@ -23,8 +31,7 @@ from .models import NativeInterface
|
|
|
23
31
|
if TYPE_CHECKING:
|
|
24
32
|
from kubernetes.client import V1PodTemplate
|
|
25
33
|
|
|
26
|
-
|
|
27
|
-
R = TypeVar("R") # return type
|
|
34
|
+
from ._task import FunctionTypes, P, R
|
|
28
35
|
|
|
29
36
|
|
|
30
37
|
@rich.repr.auto
|
|
@@ -54,7 +61,7 @@ class TaskEnvironment(Environment):
|
|
|
54
61
|
"""
|
|
55
62
|
|
|
56
63
|
cache: Union[CacheRequest] = "auto"
|
|
57
|
-
reusable:
|
|
64
|
+
reusable: ReusePolicy | None = None
|
|
58
65
|
# TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
|
|
59
66
|
# TODO also we could add list of files that are used by this environment
|
|
60
67
|
|
|
@@ -65,32 +72,42 @@ class TaskEnvironment(Environment):
|
|
|
65
72
|
name: str,
|
|
66
73
|
image: Optional[Union[str, Image, Literal["auto"]]] = None,
|
|
67
74
|
resources: Optional[Resources] = None,
|
|
68
|
-
cache: Union[CacheRequest, None] = None,
|
|
69
75
|
env: Optional[Dict[str, str]] = None,
|
|
70
|
-
reusable: Union[ReusePolicy, None] = None,
|
|
71
76
|
secrets: Optional[SecretRequest] = None,
|
|
72
77
|
env_dep_hints: Optional[List[Environment]] = None,
|
|
78
|
+
**kwargs: Any,
|
|
73
79
|
) -> TaskEnvironment:
|
|
74
80
|
"""
|
|
75
|
-
Clone the
|
|
81
|
+
Clone the TaskEnvironment with new parameters.
|
|
82
|
+
besides the base environment parameters, you can override, kwargs like `cache`, `reusable`, etc.
|
|
83
|
+
|
|
76
84
|
"""
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
85
|
+
cache = kwargs.pop("cache", None)
|
|
86
|
+
reusable = kwargs.pop("reusable", None)
|
|
87
|
+
|
|
88
|
+
# validate unknown kwargs if needed
|
|
89
|
+
if kwargs:
|
|
90
|
+
raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}")
|
|
91
|
+
|
|
92
|
+
kwargs = self._get_kwargs()
|
|
93
|
+
kwargs["name"] = name
|
|
94
|
+
if image is not None:
|
|
95
|
+
kwargs["image"] = image
|
|
96
|
+
if resources is not None:
|
|
97
|
+
kwargs["resources"] = resources
|
|
98
|
+
if cache is not None:
|
|
99
|
+
kwargs["cache"] = cache
|
|
100
|
+
if env is not None:
|
|
101
|
+
kwargs["env"] = env
|
|
102
|
+
if reusable is not None:
|
|
103
|
+
kwargs["reusable"] = reusable
|
|
104
|
+
if secrets is not None:
|
|
105
|
+
kwargs["secrets"] = secrets
|
|
106
|
+
if env_dep_hints is not None:
|
|
107
|
+
kwargs["env_dep_hints"] = env_dep_hints
|
|
108
|
+
return replace(self, **kwargs)
|
|
109
|
+
|
|
110
|
+
def task(
|
|
94
111
|
self,
|
|
95
112
|
_func=None,
|
|
96
113
|
*,
|
|
@@ -104,6 +121,7 @@ class TaskEnvironment(Environment):
|
|
|
104
121
|
report: bool = False,
|
|
105
122
|
) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
|
|
106
123
|
"""
|
|
124
|
+
:param _func: Optional The function to decorate. If not provided, the decorator will return a callable that
|
|
107
125
|
:param name: Optional The name of the task (defaults to the function name)
|
|
108
126
|
:param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
|
|
109
127
|
task.
|
|
@@ -119,25 +137,12 @@ class TaskEnvironment(Environment):
|
|
|
119
137
|
if pod_template is not None:
|
|
120
138
|
raise ValueError("Cannot set pod_template when environment is reusable.")
|
|
121
139
|
|
|
122
|
-
def decorator(func:
|
|
140
|
+
def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
|
|
123
141
|
task_name = name or func.__name__
|
|
124
142
|
task_name = self.name + "." + task_name
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
@wraps(func)
|
|
130
|
-
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
131
|
-
return await func(*args, **kwargs)
|
|
132
|
-
|
|
133
|
-
if not asyncio.iscoroutinefunction(func):
|
|
134
|
-
raise TypeError(
|
|
135
|
-
f"Function {func.__name__} is not a coroutine function. Use @env.task decorator for async tasks."
|
|
136
|
-
f"You can simply mark your function as async def {func.__name__} to make it a coroutine function, "
|
|
137
|
-
f"it is ok to write sync code in async functions, but not the other way around."
|
|
138
|
-
)
|
|
139
|
-
tmpl = AsyncFunctionTaskTemplate(
|
|
140
|
-
func=wrapper,
|
|
143
|
+
|
|
144
|
+
tmpl: AsyncFunctionTaskTemplate = AsyncFunctionTaskTemplate(
|
|
145
|
+
func=func,
|
|
141
146
|
name=task_name,
|
|
142
147
|
image=self.image,
|
|
143
148
|
resources=self.resources,
|
|
@@ -160,29 +165,6 @@ class TaskEnvironment(Environment):
|
|
|
160
165
|
return cast(AsyncFunctionTaskTemplate, decorator)
|
|
161
166
|
return cast(AsyncFunctionTaskTemplate, decorator(_func))
|
|
162
167
|
|
|
163
|
-
@property
|
|
164
|
-
def task(self) -> Callable:
|
|
165
|
-
"""
|
|
166
|
-
Decorator to create a new task with the environment settings.
|
|
167
|
-
The task will be executed in its own container with the specified image, resources, and environment variables,
|
|
168
|
-
unless reusePolicy is set, in which case the same container will be reused for all tasks with the same
|
|
169
|
-
environment settings.
|
|
170
|
-
|
|
171
|
-
:param name: Optional The name of the task (defaults to the function name)
|
|
172
|
-
:param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
|
|
173
|
-
task.
|
|
174
|
-
:param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
|
|
175
|
-
:param docs: Optional The documentation for the task, if not provided the function docstring will be used.
|
|
176
|
-
:param secrets: Optional The secrets that will be injected into the task at runtime.
|
|
177
|
-
:param timeout: Optional The timeout for the task.
|
|
178
|
-
:param pod_template: Optional The pod template for the task, if not provided the default pod template will be
|
|
179
|
-
used.
|
|
180
|
-
:param report: Optional Whether to generate the html report for the task, defaults to False.
|
|
181
|
-
|
|
182
|
-
:return: New Task instance or Task decorator
|
|
183
|
-
"""
|
|
184
|
-
return self._task
|
|
185
|
-
|
|
186
168
|
@property
|
|
187
169
|
def tasks(self) -> Dict[str, TaskTemplate]:
|
|
188
170
|
"""
|
flyte/_utils/coro_management.py
CHANGED
flyte/_version.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.2.0b5'
|
|
21
|
-
__version_tuple__ = version_tuple = (0, 2, 0, 'b5')
|
|
20
|
+
__version__ = version = '0.2.0b8'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 2, 0, 'b8')
|