flyte 0.2.0b23__py3-none-any.whl → 0.2.0b25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +1 -15
- flyte/_bin/runtime.py +10 -1
- flyte/_environment.py +7 -0
- flyte/_image.py +37 -22
- flyte/_internal/controllers/remote/_action.py +29 -0
- flyte/_internal/controllers/remote/_controller.py +33 -5
- flyte/_internal/controllers/remote/_core.py +28 -14
- flyte/_internal/imagebuild/__init__.py +1 -13
- flyte/_internal/imagebuild/docker_builder.py +5 -1
- flyte/_internal/imagebuild/remote_builder.py +11 -3
- flyte/_internal/runtime/convert.py +40 -5
- flyte/_internal/runtime/entrypoints.py +58 -15
- flyte/_internal/runtime/reuse.py +121 -0
- flyte/_internal/runtime/rusty.py +165 -0
- flyte/_internal/runtime/task_serde.py +26 -35
- flyte/_pod.py +11 -1
- flyte/_resources.py +67 -2
- flyte/_reusable_environment.py +57 -2
- flyte/_run.py +35 -16
- flyte/_secret.py +30 -0
- flyte/_task.py +1 -5
- flyte/_task_environment.py +34 -2
- flyte/_task_plugins.py +45 -0
- flyte/_version.py +2 -2
- flyte/errors.py +2 -2
- flyte/extend.py +12 -0
- flyte/models.py +10 -1
- flyte/types/_type_engine.py +16 -2
- flyte-0.2.0b25.data/scripts/runtime.py +169 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/METADATA +1 -1
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/RECORD +34 -29
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b23.dist-info → flyte-0.2.0b25.dist-info}/top_level.txt +0 -0
flyte/_run.py
CHANGED
|
@@ -6,13 +6,9 @@ import uuid
|
|
|
6
6
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
|
|
7
7
|
|
|
8
8
|
import flyte.errors
|
|
9
|
-
from flyte.
|
|
10
|
-
from flyte.
|
|
11
|
-
from flyte.
|
|
12
|
-
|
|
13
|
-
from ._context import contextual_run, internal_ctx
|
|
14
|
-
from ._environment import Environment
|
|
15
|
-
from ._initialize import (
|
|
9
|
+
from flyte._context import contextual_run, internal_ctx
|
|
10
|
+
from flyte._environment import Environment
|
|
11
|
+
from flyte._initialize import (
|
|
16
12
|
_get_init_config,
|
|
17
13
|
get_client,
|
|
18
14
|
get_common_config,
|
|
@@ -20,9 +16,19 @@ from ._initialize import (
|
|
|
20
16
|
requires_initialization,
|
|
21
17
|
requires_storage,
|
|
22
18
|
)
|
|
23
|
-
from ._logging import logger
|
|
24
|
-
from ._task import P, R, TaskTemplate
|
|
25
|
-
from ._tools import ipython_check
|
|
19
|
+
from flyte._logging import logger
|
|
20
|
+
from flyte._task import P, R, TaskTemplate
|
|
21
|
+
from flyte._tools import ipython_check
|
|
22
|
+
from flyte.errors import InitializationError
|
|
23
|
+
from flyte.models import (
|
|
24
|
+
ActionID,
|
|
25
|
+
Checkpoints,
|
|
26
|
+
CodeBundle,
|
|
27
|
+
RawDataPath,
|
|
28
|
+
SerializationContext,
|
|
29
|
+
TaskContext,
|
|
30
|
+
)
|
|
31
|
+
from flyte.syncify import syncify
|
|
26
32
|
|
|
27
33
|
if TYPE_CHECKING:
|
|
28
34
|
from flyte.remote import Run
|
|
@@ -132,7 +138,9 @@ class _Runner:
|
|
|
132
138
|
|
|
133
139
|
if self._interactive_mode:
|
|
134
140
|
code_bundle = await build_pkl_bundle(
|
|
135
|
-
obj,
|
|
141
|
+
obj,
|
|
142
|
+
upload_to_controlplane=not self._dry_run,
|
|
143
|
+
copy_bundle_to=self._copy_bundle_to,
|
|
136
144
|
)
|
|
137
145
|
else:
|
|
138
146
|
if self._copy_files != "none":
|
|
@@ -253,7 +261,8 @@ class _Runner:
|
|
|
253
261
|
pb2=run_definition_pb2.Run(
|
|
254
262
|
action=run_definition_pb2.Action(
|
|
255
263
|
id=run_definition_pb2.ActionIdentifier(
|
|
256
|
-
name="a0",
|
|
264
|
+
name="a0",
|
|
265
|
+
run=run_definition_pb2.RunIdentifier(name="dry-run"),
|
|
257
266
|
)
|
|
258
267
|
)
|
|
259
268
|
)
|
|
@@ -286,7 +295,7 @@ class _Runner:
|
|
|
286
295
|
if obj.parent_env is None:
|
|
287
296
|
raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
|
|
288
297
|
|
|
289
|
-
image_cache = build_images.aio(cast(Environment, obj.parent_env()))
|
|
298
|
+
image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
|
|
290
299
|
|
|
291
300
|
code_bundle = None
|
|
292
301
|
if self._name is not None:
|
|
@@ -296,7 +305,9 @@ class _Runner:
|
|
|
296
305
|
if not code_bundle:
|
|
297
306
|
if self._interactive_mode:
|
|
298
307
|
code_bundle = await build_pkl_bundle(
|
|
299
|
-
obj,
|
|
308
|
+
obj,
|
|
309
|
+
upload_to_controlplane=not self._dry_run,
|
|
310
|
+
copy_bundle_to=self._copy_bundle_to,
|
|
300
311
|
)
|
|
301
312
|
else:
|
|
302
313
|
if self._copy_files != "none":
|
|
@@ -381,7 +392,8 @@ class _Runner:
|
|
|
381
392
|
tctx = TaskContext(
|
|
382
393
|
action=action,
|
|
383
394
|
checkpoints=Checkpoints(
|
|
384
|
-
prev_checkpoint_path=internal_ctx().raw_data.path,
|
|
395
|
+
prev_checkpoint_path=internal_ctx().raw_data.path,
|
|
396
|
+
checkpoint_path=internal_ctx().raw_data.path,
|
|
385
397
|
),
|
|
386
398
|
code_bundle=None,
|
|
387
399
|
output_path=self._metadata_path,
|
|
@@ -403,7 +415,10 @@ class _Runner:
|
|
|
403
415
|
|
|
404
416
|
@syncify
|
|
405
417
|
async def run(
|
|
406
|
-
self,
|
|
418
|
+
self,
|
|
419
|
+
task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
|
|
420
|
+
*args: P.args,
|
|
421
|
+
**kwargs: P.kwargs,
|
|
407
422
|
) -> Union[R, Run]:
|
|
408
423
|
"""
|
|
409
424
|
Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.
|
|
@@ -430,6 +445,10 @@ class _Runner:
|
|
|
430
445
|
|
|
431
446
|
if isinstance(task, LazyEntity) and self._mode != "remote":
|
|
432
447
|
raise ValueError("Remote task can only be run in remote mode.")
|
|
448
|
+
|
|
449
|
+
if not isinstance(task, TaskTemplate) and not isinstance(task, LazyEntity):
|
|
450
|
+
raise TypeError("On Flyte tasks can be run, not generic functions or methods.")
|
|
451
|
+
|
|
433
452
|
if self._mode == "remote":
|
|
434
453
|
return await self._run_remote(task, *args, **kwargs)
|
|
435
454
|
task = cast(TaskTemplate, task)
|
flyte/_secret.py
CHANGED
|
@@ -45,6 +45,27 @@ class Secret:
|
|
|
45
45
|
if not re.match(pattern, self.as_env_var):
|
|
46
46
|
raise ValueError(f"Invalid environment variable name: {self.as_env_var}, must match {pattern}")
|
|
47
47
|
|
|
48
|
+
def stable_hash(self) -> str:
|
|
49
|
+
"""
|
|
50
|
+
Deterministic, process-independent hash (as hex string).
|
|
51
|
+
"""
|
|
52
|
+
import hashlib
|
|
53
|
+
|
|
54
|
+
data = (
|
|
55
|
+
self.key,
|
|
56
|
+
self.group or "",
|
|
57
|
+
str(self.mount) if self.mount else "",
|
|
58
|
+
self.as_env_var or "",
|
|
59
|
+
)
|
|
60
|
+
joined = "|".join(data)
|
|
61
|
+
return hashlib.sha256(joined.encode("utf-8")).hexdigest()
|
|
62
|
+
|
|
63
|
+
def __hash__(self) -> int:
|
|
64
|
+
"""
|
|
65
|
+
Deterministic hash function for the Secret class.
|
|
66
|
+
"""
|
|
67
|
+
return int(self.stable_hash()[:16], 16)
|
|
68
|
+
|
|
48
69
|
|
|
49
70
|
SecretRequest = Union[str, Secret, List[str | Secret]]
|
|
50
71
|
|
|
@@ -59,3 +80,12 @@ def secrets_from_request(secrets: SecretRequest) -> List[Secret]:
|
|
|
59
80
|
return [secrets]
|
|
60
81
|
else:
|
|
61
82
|
return [Secret(key=s) if isinstance(s, str) else s for s in secrets]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
if __name__ == "__main__":
|
|
86
|
+
# Example usage
|
|
87
|
+
secret1 = Secret(key="MY_SECRET", mount=pathlib.Path("/path/to/secret"), as_env_var="MY_SECRET_ENV")
|
|
88
|
+
secret2 = Secret(
|
|
89
|
+
key="ANOTHER_SECRET",
|
|
90
|
+
)
|
|
91
|
+
print(hash(secret1), hash(secret2))
|
flyte/_task.py
CHANGED
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import weakref
|
|
5
5
|
from dataclasses import dataclass, field, replace
|
|
6
|
-
from functools import cached_property
|
|
7
6
|
from inspect import iscoroutinefunction
|
|
8
7
|
from typing import (
|
|
9
8
|
TYPE_CHECKING,
|
|
@@ -361,16 +360,13 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
|
|
|
361
360
|
"""
|
|
362
361
|
|
|
363
362
|
func: FunctionTypes
|
|
363
|
+
plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration
|
|
364
364
|
|
|
365
365
|
def __post_init__(self):
|
|
366
366
|
super().__post_init__()
|
|
367
367
|
if not iscoroutinefunction(self.func):
|
|
368
368
|
self._call_as_synchronous = True
|
|
369
369
|
|
|
370
|
-
@cached_property
|
|
371
|
-
def native_interface(self) -> NativeInterface:
|
|
372
|
-
return NativeInterface.from_callable(self.func)
|
|
373
|
-
|
|
374
370
|
def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
|
|
375
371
|
# In local execution, we want to just call the function. Note we're not awaiting anything here.
|
|
376
372
|
# If the function was a coroutine function, the coroutine is returned and the await that the caller has
|
flyte/_task_environment.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import inspect
|
|
3
4
|
import weakref
|
|
4
5
|
from dataclasses import dataclass, field, replace
|
|
5
6
|
from datetime import timedelta
|
|
@@ -60,13 +61,19 @@ class TaskEnvironment(Environment):
|
|
|
60
61
|
:param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
|
|
61
62
|
"""
|
|
62
63
|
|
|
63
|
-
cache: Union[CacheRequest] = "
|
|
64
|
+
cache: Union[CacheRequest] = "disable"
|
|
64
65
|
reusable: ReusePolicy | None = None
|
|
66
|
+
plugin_config: Optional[Any] = None
|
|
65
67
|
# TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
|
|
66
68
|
# TODO also we could add list of files that are used by this environment
|
|
67
69
|
|
|
68
70
|
_tasks: Dict[str, TaskTemplate] = field(default_factory=dict, init=False)
|
|
69
71
|
|
|
72
|
+
def __post_init__(self) -> None:
|
|
73
|
+
super().__post_init__()
|
|
74
|
+
if self.reusable is not None and self.plugin_config is not None:
|
|
75
|
+
raise ValueError("Cannot set plugin_config when environment is reusable.")
|
|
76
|
+
|
|
70
77
|
def clone_with(
|
|
71
78
|
self,
|
|
72
79
|
name: str,
|
|
@@ -133,6 +140,8 @@ class TaskEnvironment(Environment):
|
|
|
133
140
|
used.
|
|
134
141
|
:param report: Optional Whether to generate the html report for the task, defaults to False.
|
|
135
142
|
"""
|
|
143
|
+
from ._task import P, R
|
|
144
|
+
|
|
136
145
|
if self.reusable is not None:
|
|
137
146
|
if pod_template is not None:
|
|
138
147
|
raise ValueError("Cannot set pod_template when environment is reusable.")
|
|
@@ -141,7 +150,29 @@ class TaskEnvironment(Environment):
|
|
|
141
150
|
friendly_name = name or func.__name__
|
|
142
151
|
task_name = self.name + "." + func.__name__
|
|
143
152
|
|
|
144
|
-
|
|
153
|
+
if not inspect.iscoroutinefunction(func) and self.reusable is not None:
|
|
154
|
+
if self.reusable.concurrency > 1:
|
|
155
|
+
raise ValueError(
|
|
156
|
+
"Reusable environments with concurrency greater than 1 are only supported for async tasks. "
|
|
157
|
+
"Please use an async function or set concurrency to 1."
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
if self.plugin_config is not None:
|
|
161
|
+
from flyte.extend import TaskPluginRegistry
|
|
162
|
+
|
|
163
|
+
task_template_class: type[AsyncFunctionTaskTemplate[P, R]] | None = TaskPluginRegistry.find(
|
|
164
|
+
config_type=type(self.plugin_config)
|
|
165
|
+
)
|
|
166
|
+
if task_template_class is None:
|
|
167
|
+
raise ValueError(
|
|
168
|
+
f"No task plugin found for config type {type(self.plugin_config)}. "
|
|
169
|
+
f"Please register a plugin using flyte.extend.TaskPluginRegistry.register() api."
|
|
170
|
+
)
|
|
171
|
+
else:
|
|
172
|
+
task_template_class = AsyncFunctionTaskTemplate[P, R]
|
|
173
|
+
|
|
174
|
+
task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R]], task_template_class)
|
|
175
|
+
tmpl = task_template_class(
|
|
145
176
|
func=func,
|
|
146
177
|
name=task_name,
|
|
147
178
|
image=self.image,
|
|
@@ -158,6 +189,7 @@ class TaskEnvironment(Environment):
|
|
|
158
189
|
interface=NativeInterface.from_callable(func),
|
|
159
190
|
report=report,
|
|
160
191
|
friendly_name=friendly_name,
|
|
192
|
+
plugin_config=self.plugin_config,
|
|
161
193
|
)
|
|
162
194
|
self._tasks[task_name] = tmpl
|
|
163
195
|
return tmpl
|
flyte/_task_plugins.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
from typing import Type
|
|
3
|
+
|
|
4
|
+
import rich.repr
|
|
5
|
+
|
|
6
|
+
if typing.TYPE_CHECKING:
|
|
7
|
+
from ._task import AsyncFunctionTaskTemplate
|
|
8
|
+
|
|
9
|
+
T = typing.TypeVar("T", bound="AsyncFunctionTaskTemplate")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _Registry:
|
|
13
|
+
"""
|
|
14
|
+
A registry for task plugins.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
def __init__(self):
|
|
18
|
+
self._plugins: typing.Dict[Type, Type[T]] = {}
|
|
19
|
+
|
|
20
|
+
def register(self, config_type: Type, plugin: Type[T]):
|
|
21
|
+
"""
|
|
22
|
+
Register a plugin.
|
|
23
|
+
"""
|
|
24
|
+
self._plugins[config_type] = plugin
|
|
25
|
+
|
|
26
|
+
def find(self, config_type: Type) -> typing.Optional[Type[T]]:
|
|
27
|
+
"""
|
|
28
|
+
Get a plugin by name.
|
|
29
|
+
"""
|
|
30
|
+
return self._plugins.get(config_type)
|
|
31
|
+
|
|
32
|
+
def list_plugins(self):
|
|
33
|
+
"""
|
|
34
|
+
List all registered plugins.
|
|
35
|
+
"""
|
|
36
|
+
return list(self._plugins.keys())
|
|
37
|
+
|
|
38
|
+
def __rich_repr__(self) -> "rich.repr.Result":
|
|
39
|
+
yield from (("Name", i) for i in self.list_plugins())
|
|
40
|
+
|
|
41
|
+
def __repr__(self):
|
|
42
|
+
return f"TaskPluginRegistry(plugins={self.list_plugins()})"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
TaskPluginRegistry = _Registry()
|
flyte/_version.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.2.0b23'
|
|
21
|
-
__version_tuple__ = version_tuple = (0, 2, 0, 'b23')
|
|
20
|
+
__version__ = version = '0.2.0b25'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 2, 0, 'b25')
|
flyte/errors.py
CHANGED
|
@@ -157,9 +157,9 @@ class RuntimeDataValidationError(RuntimeUserError):
|
|
|
157
157
|
This error is raised when the user tries to access a resource that does not exist or is invalid.
|
|
158
158
|
"""
|
|
159
159
|
|
|
160
|
-
def __init__(self, var: str, e: Exception, task_name: str = ""):
|
|
160
|
+
def __init__(self, var: str, e: Exception | str, task_name: str = ""):
|
|
161
161
|
super().__init__(
|
|
162
|
-
"DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because {e}"
|
|
162
|
+
"DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because of {e}"
|
|
163
163
|
)
|
|
164
164
|
|
|
165
165
|
|
flyte/extend.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from ._initialize import is_initialized
|
|
2
|
+
from ._resources import PRIMARY_CONTAINER_DEFAULT_NAME, pod_spec_from_resources
|
|
3
|
+
from ._task import AsyncFunctionTaskTemplate
|
|
4
|
+
from ._task_plugins import TaskPluginRegistry
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"PRIMARY_CONTAINER_DEFAULT_NAME",
|
|
8
|
+
"AsyncFunctionTaskTemplate",
|
|
9
|
+
"TaskPluginRegistry",
|
|
10
|
+
"is_initialized",
|
|
11
|
+
"pod_spec_from_resources",
|
|
12
|
+
]
|
flyte/models.py
CHANGED
|
@@ -256,6 +256,7 @@ class NativeInterface:
|
|
|
256
256
|
_remote_defaults: Optional[Dict[str, literals_pb2.Literal]] = field(default=None, repr=False)
|
|
257
257
|
|
|
258
258
|
has_default: ClassVar[Type[_has_default]] = _has_default # This can be used to indicate if a specific input
|
|
259
|
+
|
|
259
260
|
# has a default value or not, in the case when the default value is not known. An example would be remote tasks.
|
|
260
261
|
|
|
261
262
|
def has_outputs(self) -> bool:
|
|
@@ -298,7 +299,15 @@ class NativeInterface:
|
|
|
298
299
|
sig = inspect.signature(func)
|
|
299
300
|
|
|
300
301
|
# Extract parameter details (name, type, default value)
|
|
301
|
-
param_info = {
|
|
302
|
+
param_info = {}
|
|
303
|
+
for name, param in sig.parameters.items():
|
|
304
|
+
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
|
305
|
+
raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
|
|
306
|
+
if param.annotation is inspect.Parameter.empty:
|
|
307
|
+
logger.warning(
|
|
308
|
+
f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
|
|
309
|
+
)
|
|
310
|
+
param_info[name] = (param.annotation, param.default)
|
|
302
311
|
|
|
303
312
|
# Get return type
|
|
304
313
|
outputs = extract_return_annotation(sig.return_annotation)
|
flyte/types/_type_engine.py
CHANGED
|
@@ -1315,7 +1315,8 @@ class TypeEngine(typing.Generic[T]):
|
|
|
1315
1315
|
try:
|
|
1316
1316
|
return transformer.guess_python_type(flyte_type)
|
|
1317
1317
|
except ValueError:
|
|
1318
|
-
|
|
1318
|
+
# Skipping transformer
|
|
1319
|
+
continue
|
|
1319
1320
|
|
|
1320
1321
|
# Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
|
|
1321
1322
|
# separately here too.
|
|
@@ -1438,6 +1439,19 @@ def _type_essence(x: types_pb2.LiteralType) -> types_pb2.LiteralType:
|
|
|
1438
1439
|
|
|
1439
1440
|
|
|
1440
1441
|
def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.LiteralType) -> bool:
|
|
1442
|
+
if upstream.union_type is not None:
|
|
1443
|
+
# for each upstream variant, there must be a compatible type downstream
|
|
1444
|
+
for v in upstream.union_type.variants:
|
|
1445
|
+
if not _are_types_castable(v, downstream):
|
|
1446
|
+
return False
|
|
1447
|
+
return True
|
|
1448
|
+
|
|
1449
|
+
if downstream.union_type is not None:
|
|
1450
|
+
# there must be a compatible downstream type
|
|
1451
|
+
for v in downstream.union_type.variants:
|
|
1452
|
+
if _are_types_castable(upstream, v):
|
|
1453
|
+
return True
|
|
1454
|
+
|
|
1441
1455
|
if upstream.HasField("collection_type"):
|
|
1442
1456
|
if not downstream.HasField("collection_type"):
|
|
1443
1457
|
return False
|
|
@@ -1483,7 +1497,7 @@ def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.L
|
|
|
1483
1497
|
|
|
1484
1498
|
return True
|
|
1485
1499
|
|
|
1486
|
-
if upstream.HasField("union_type"):
|
|
1500
|
+
if upstream.HasField("union_type") and upstream.union_type is not None:
|
|
1487
1501
|
# for each upstream variant, there must be a compatible type downstream
|
|
1488
1502
|
for v in upstream.union_type.variants:
|
|
1489
1503
|
if not _are_types_castable(v, downstream):
|
|
flyte-0.2.0b25.data/scripts/runtime.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Flyte runtime module, this is the entrypoint script for the Flyte runtime.
|
|
3
|
+
|
|
4
|
+
Caution: Startup time for this module is very important, as it is the entrypoint for the Flyte runtime.
|
|
5
|
+
Refrain from importing any modules here. If you need to import any modules, do it inside the main function.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from typing import Any, List
|
|
12
|
+
|
|
13
|
+
import click
|
|
14
|
+
|
|
15
|
+
# Todo: work with pvditt to make these the names
|
|
16
|
+
# ACTION_NAME = "_U_ACTION_NAME"
|
|
17
|
+
# RUN_NAME = "_U_RUN_NAME"
|
|
18
|
+
# PROJECT_NAME = "_U_PROJECT_NAME"
|
|
19
|
+
# DOMAIN_NAME = "_U_DOMAIN_NAME"
|
|
20
|
+
# ORG_NAME = "_U_ORG_NAME"
|
|
21
|
+
|
|
22
|
+
ACTION_NAME = "ACTION_NAME"
|
|
23
|
+
RUN_NAME = "RUN_NAME"
|
|
24
|
+
PROJECT_NAME = "FLYTE_INTERNAL_TASK_PROJECT"
|
|
25
|
+
DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
|
|
26
|
+
ORG_NAME = "_U_ORG_NAME"
|
|
27
|
+
ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
|
|
28
|
+
RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
|
|
29
|
+
ENABLE_REF_TASKS = "_REF_TASKS" # This is a temporary flag to enable reference tasks in the runtime.
|
|
30
|
+
|
|
31
|
+
# TODO: Remove this after proper auth is implemented
|
|
32
|
+
_UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@click.group()
|
|
36
|
+
def _pass_through():
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@_pass_through.command("a0")
|
|
41
|
+
@click.option("--inputs", "-i", required=True)
|
|
42
|
+
@click.option("--outputs-path", "-o", required=True)
|
|
43
|
+
@click.option("--version", "-v", required=True)
|
|
44
|
+
@click.option("--run-base-dir", envvar=RUN_OUTPUT_BASE_DIR, required=True)
|
|
45
|
+
@click.option("--raw-data-path", "-r", required=False)
|
|
46
|
+
@click.option("--checkpoint-path", "-c", required=False)
|
|
47
|
+
@click.option("--prev-checkpoint", "-p", required=False)
|
|
48
|
+
@click.option("--name", envvar=ACTION_NAME, required=False)
|
|
49
|
+
@click.option("--run-name", envvar=RUN_NAME, required=False)
|
|
50
|
+
@click.option("--project", envvar=PROJECT_NAME, required=False)
|
|
51
|
+
@click.option("--domain", envvar=DOMAIN_NAME, required=False)
|
|
52
|
+
@click.option("--org", envvar=ORG_NAME, required=False)
|
|
53
|
+
@click.option("--image-cache", required=False)
|
|
54
|
+
@click.option("--tgz", required=False)
|
|
55
|
+
@click.option("--pkl", required=False)
|
|
56
|
+
@click.option("--dest", required=False)
|
|
57
|
+
@click.option("--resolver", required=False)
|
|
58
|
+
@click.argument(
|
|
59
|
+
"resolver-args",
|
|
60
|
+
type=click.UNPROCESSED,
|
|
61
|
+
nargs=-1,
|
|
62
|
+
)
|
|
63
|
+
def main(
|
|
64
|
+
run_name: str,
|
|
65
|
+
name: str,
|
|
66
|
+
project: str,
|
|
67
|
+
domain: str,
|
|
68
|
+
org: str,
|
|
69
|
+
image_cache: str,
|
|
70
|
+
version: str,
|
|
71
|
+
inputs: str,
|
|
72
|
+
run_base_dir: str,
|
|
73
|
+
outputs_path: str,
|
|
74
|
+
raw_data_path: str,
|
|
75
|
+
checkpoint_path: str,
|
|
76
|
+
prev_checkpoint: str,
|
|
77
|
+
tgz: str,
|
|
78
|
+
pkl: str,
|
|
79
|
+
dest: str,
|
|
80
|
+
resolver: str,
|
|
81
|
+
resolver_args: List[str],
|
|
82
|
+
):
|
|
83
|
+
sys.path.insert(0, ".")
|
|
84
|
+
|
|
85
|
+
import flyte
|
|
86
|
+
import flyte._utils as utils
|
|
87
|
+
from flyte._initialize import init
|
|
88
|
+
from flyte._internal.controllers import create_controller
|
|
89
|
+
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
90
|
+
from flyte._internal.runtime.entrypoints import load_and_run_task
|
|
91
|
+
from flyte._logging import logger
|
|
92
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
93
|
+
|
|
94
|
+
logger.info(f"Initializing flyte runtime - version {flyte.__version__}")
|
|
95
|
+
|
|
96
|
+
assert org, "Org is required for now"
|
|
97
|
+
assert project, "Project is required"
|
|
98
|
+
assert domain, "Domain is required"
|
|
99
|
+
assert run_name, f"Run name is required {run_name}"
|
|
100
|
+
assert name, f"Action name is required {name}"
|
|
101
|
+
|
|
102
|
+
if run_name.startswith("{{"):
|
|
103
|
+
run_name = os.getenv("RUN_NAME", "")
|
|
104
|
+
if name.startswith("{{"):
|
|
105
|
+
name = os.getenv("ACTION_NAME", "")
|
|
106
|
+
|
|
107
|
+
# Figure out how to connect
|
|
108
|
+
# This detection of api key is a hack for now.
|
|
109
|
+
controller_kwargs: dict[str, Any] = {"insecure": False}
|
|
110
|
+
if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
|
|
111
|
+
logger.info("Using api key from environment")
|
|
112
|
+
controller_kwargs["api_key"] = api_key
|
|
113
|
+
else:
|
|
114
|
+
ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
|
|
115
|
+
controller_kwargs["endpoint"] = ep
|
|
116
|
+
if "localhost" in ep or "docker" in ep:
|
|
117
|
+
controller_kwargs["insecure"] = True
|
|
118
|
+
logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
|
|
119
|
+
|
|
120
|
+
bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
|
|
121
|
+
enable_ref_tasks = os.getenv(ENABLE_REF_TASKS, "false").lower() in ("true", "1", "yes")
|
|
122
|
+
# We init regular client here so that reference tasks can work
|
|
123
|
+
# Current reference tasks will not work with remote controller, because we create 2 different
|
|
124
|
+
# channels on different threads and this is not supported by grpcio or the auth system. It ends up leading
|
|
125
|
+
# File "src/python/grpcio/grpc/_cython/_cygrpc/aio/completion_queue.pyx.pxi", line 147,
|
|
126
|
+
# in grpc._cython.cygrpc.PollerCompletionQueue._handle_events
|
|
127
|
+
# BlockingIOError: [Errno 11] Resource temporarily unavailable
|
|
128
|
+
# TODO solution is to use a single channel for both controller and reference tasks, but this requires a refactor
|
|
129
|
+
if enable_ref_tasks:
|
|
130
|
+
logger.warning(
|
|
131
|
+
"Reference tasks are enabled. This will initialize client and you will see a BlockIOError. "
|
|
132
|
+
"This is harmless, but a nuisance. We are working on a fix."
|
|
133
|
+
)
|
|
134
|
+
init(org=org, project=project, domain=domain, **controller_kwargs)
|
|
135
|
+
else:
|
|
136
|
+
init()
|
|
137
|
+
# Controller is created with the same kwargs as init, so that it can be used to run tasks
|
|
138
|
+
controller = create_controller(ct="remote", **controller_kwargs)
|
|
139
|
+
|
|
140
|
+
ic = ImageCache.from_transport(image_cache) if image_cache else None
|
|
141
|
+
|
|
142
|
+
# Create a coroutine to load the task and run it
|
|
143
|
+
task_coroutine = load_and_run_task(
|
|
144
|
+
resolver=resolver,
|
|
145
|
+
resolver_args=resolver_args,
|
|
146
|
+
action=ActionID(name=name, run_name=run_name, project=project, domain=domain, org=org),
|
|
147
|
+
raw_data_path=RawDataPath(path=raw_data_path),
|
|
148
|
+
checkpoints=Checkpoints(checkpoint_path, prev_checkpoint),
|
|
149
|
+
code_bundle=bundle,
|
|
150
|
+
input_path=inputs,
|
|
151
|
+
output_path=outputs_path,
|
|
152
|
+
run_base_dir=run_base_dir,
|
|
153
|
+
version=version,
|
|
154
|
+
controller=controller,
|
|
155
|
+
image_cache=ic,
|
|
156
|
+
)
|
|
157
|
+
# Create a coroutine to watch for errors
|
|
158
|
+
controller_failure = controller.watch_for_errors()
|
|
159
|
+
|
|
160
|
+
# Run both coroutines concurrently and wait for first to finish and cancel the other
|
|
161
|
+
async def _run_and_stop():
|
|
162
|
+
await utils.run_coros(controller_failure, task_coroutine)
|
|
163
|
+
await controller.stop()
|
|
164
|
+
|
|
165
|
+
asyncio.run(_run_and_stop())
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
if __name__ == "__main__":
|
|
169
|
+
_pass_through()
|