flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +18 -2
- flyte/_bin/runtime.py +43 -5
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +1 -1
- flyte/_code_bundle/_packaging.py +4 -4
- flyte/_code_bundle/_utils.py +14 -8
- flyte/_code_bundle/bundle.py +13 -5
- flyte/_constants.py +1 -0
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +0 -1
- flyte/_debug/vscode.py +6 -1
- flyte/_deploy.py +223 -59
- flyte/_environment.py +5 -0
- flyte/_excepthook.py +1 -1
- flyte/_image.py +144 -82
- flyte/_initialize.py +95 -12
- flyte/_interface.py +2 -0
- flyte/_internal/controllers/_local_controller.py +65 -24
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/_action.py +13 -11
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +9 -4
- flyte/_internal/controllers/remote/_core.py +16 -16
- flyte/_internal/controllers/remote/_informer.py +4 -4
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +139 -84
- flyte/_internal/imagebuild/image_builder.py +7 -13
- flyte/_internal/imagebuild/remote_builder.py +65 -13
- flyte/_internal/imagebuild/utils.py +51 -3
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +42 -20
- flyte/_internal/runtime/entrypoints.py +24 -1
- flyte/_internal/runtime/io.py +21 -8
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/rusty.py +20 -5
- flyte/_internal/runtime/task_serde.py +33 -27
- flyte/_internal/runtime/taskrunner.py +10 -1
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/file.py +39 -9
- flyte/_logging.py +79 -12
- flyte/_map.py +31 -12
- flyte/_module.py +70 -0
- flyte/_pod.py +2 -2
- flyte/_resources.py +213 -31
- flyte/_run.py +107 -41
- flyte/_task.py +66 -10
- flyte/_task_environment.py +96 -24
- flyte/_task_plugins.py +4 -2
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +17 -2
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +1 -3
- flyte/cli/_common.py +78 -7
- flyte/cli/_create.py +178 -3
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +49 -11
- flyte/cli/_get.py +79 -34
- flyte/cli/_params.py +8 -6
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +127 -11
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +30 -4
- flyte/config/_config.py +2 -0
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +3 -3
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +10 -1
- flyte/extend.py +8 -1
- flyte/extras/_container.py +6 -1
- flyte/git/_config.py +11 -9
- flyte/io/__init__.py +2 -0
- flyte/io/_dataframe/__init__.py +2 -0
- flyte/io/_dataframe/basic_dfs.py +1 -1
- flyte/io/_dataframe/dataframe.py +12 -8
- flyte/io/_dir.py +551 -120
- flyte/io/_file.py +538 -141
- flyte/models.py +57 -12
- flyte/remote/__init__.py +6 -1
- flyte/remote/_action.py +18 -16
- flyte/remote/_client/_protocols.py +39 -4
- flyte/remote/_client/auth/_channel.py +10 -6
- flyte/remote/_client/controlplane.py +17 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_data.py +4 -3
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +47 -7
- flyte/remote/_secret.py +26 -17
- flyte/remote/_task.py +21 -9
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/storage/__init__.py +6 -1
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +185 -103
- flyte/types/__init__.py +16 -0
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +17 -4
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +26 -19
- flyte/types/_utils.py +1 -1
- {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
- flyte-2.0.0b30.dist-info/RECORD +192 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -99
- flyte/_protos/common/identifier_pb2.pyi +0 -120
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -60
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -27
- flyte/_protos/workflow/common_pb2.pyi +0 -14
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -111
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -123
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -137
- flyte/_protos/workflow/run_service_pb2.pyi +0 -185
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -82
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -60
- flyte/_protos/workflow/task_service_pb2.pyi +0 -59
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte-2.0.0b22.dist-info/RECORD +0 -250
- {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,11 @@
|
|
|
1
|
-
import inspect
|
|
2
|
-
import os
|
|
3
1
|
import pathlib
|
|
4
|
-
import sys
|
|
5
2
|
from typing import Tuple
|
|
6
3
|
|
|
4
|
+
from flyte._module import extract_obj_module
|
|
7
5
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
8
6
|
|
|
9
7
|
|
|
10
|
-
def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path
|
|
8
|
+
def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path) -> Tuple[str, str]:
|
|
11
9
|
"""
|
|
12
10
|
Extract the task module from the task template.
|
|
13
11
|
|
|
@@ -15,40 +13,9 @@ def extract_task_module(task: TaskTemplate, /, source_dir: pathlib.Path | None =
|
|
|
15
13
|
:param source_dir: The source directory to use for relative paths.
|
|
16
14
|
:return: A tuple containing the entity name, module
|
|
17
15
|
"""
|
|
18
|
-
entity_name = task.name
|
|
19
16
|
if isinstance(task, AsyncFunctionTaskTemplate):
|
|
20
|
-
entity_module = inspect.getmodule(task.func)
|
|
21
|
-
if entity_module is None:
|
|
22
|
-
raise ValueError(f"Task {entity_name} has no module.")
|
|
23
|
-
|
|
24
|
-
fp = entity_module.__file__
|
|
25
|
-
if fp is None:
|
|
26
|
-
raise ValueError(f"Task {entity_name} has no module.")
|
|
27
|
-
|
|
28
|
-
file_path = pathlib.Path(fp)
|
|
29
|
-
# Get the relative path to the current directory
|
|
30
|
-
# Will raise ValueError if the file is not in the source directory
|
|
31
|
-
relative_path = file_path.relative_to(str(source_dir))
|
|
32
|
-
|
|
33
|
-
if relative_path == pathlib.Path("."):
|
|
34
|
-
entity_module_name = entity_module.__name__
|
|
35
|
-
else:
|
|
36
|
-
# Replace file separators with dots and remove the '.py' extension
|
|
37
|
-
dotted_path = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")
|
|
38
|
-
entity_module_name = dotted_path
|
|
39
|
-
|
|
40
17
|
entity_name = task.func.__name__
|
|
18
|
+
entity_module_name = extract_obj_module(task.func, source_dir)
|
|
19
|
+
return entity_name, entity_module_name
|
|
41
20
|
else:
|
|
42
|
-
raise NotImplementedError(f"Task module {
|
|
43
|
-
|
|
44
|
-
if entity_module_name == "__main__":
|
|
45
|
-
"""
|
|
46
|
-
This case is for the case in which the task is run from the main module.
|
|
47
|
-
"""
|
|
48
|
-
fp = sys.modules["__main__"].__file__
|
|
49
|
-
if fp is None:
|
|
50
|
-
raise ValueError(f"Task {entity_name} has no module.")
|
|
51
|
-
main_path = pathlib.Path(fp)
|
|
52
|
-
entity_module_name = main_path.stem
|
|
53
|
-
|
|
54
|
-
return entity_name, entity_module_name
|
|
21
|
+
raise NotImplementedError(f"Task module {task.name} not implemented.")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import List
|
|
3
|
+
from typing import List
|
|
4
4
|
|
|
5
5
|
from flyte._internal.resolvers._task_module import extract_task_module
|
|
6
6
|
from flyte._internal.resolvers.common import Resolver
|
|
@@ -23,6 +23,6 @@ class DefaultTaskResolver(Resolver):
|
|
|
23
23
|
task_def = getattr(task_module, task_name)
|
|
24
24
|
return task_def
|
|
25
25
|
|
|
26
|
-
def loader_args(self, task: TaskTemplate, root_dir:
|
|
26
|
+
def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]: # type:ignore
|
|
27
27
|
t, m = extract_task_module(task, root_dir)
|
|
28
28
|
return ["mod", m, "instance", t]
|
|
@@ -8,27 +8,33 @@ from dataclasses import dataclass
|
|
|
8
8
|
from types import NoneType
|
|
9
9
|
from typing import Any, Dict, List, Tuple, Union, get_args
|
|
10
10
|
|
|
11
|
-
from
|
|
11
|
+
from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
|
|
12
|
+
from flyteidl2.task import common_pb2, task_definition_pb2
|
|
12
13
|
|
|
13
14
|
import flyte.errors
|
|
14
15
|
import flyte.storage as storage
|
|
15
|
-
from flyte.
|
|
16
|
+
from flyte._context import ctx
|
|
16
17
|
from flyte.models import ActionID, NativeInterface, TaskContext
|
|
17
18
|
from flyte.types import TypeEngine, TypeTransformerFailedError
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
@dataclass(frozen=True)
|
|
21
22
|
class Inputs:
|
|
22
|
-
proto_inputs:
|
|
23
|
+
proto_inputs: common_pb2.Inputs
|
|
23
24
|
|
|
24
25
|
@classmethod
|
|
25
26
|
def empty(cls) -> "Inputs":
|
|
26
|
-
return cls(proto_inputs=
|
|
27
|
+
return cls(proto_inputs=common_pb2.Inputs())
|
|
28
|
+
|
|
29
|
+
@property
|
|
30
|
+
def context(self) -> Dict[str, str]:
|
|
31
|
+
"""Get the context as a dictionary."""
|
|
32
|
+
return {kv.key: kv.value for kv in self.proto_inputs.context}
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
@dataclass(frozen=True)
|
|
30
36
|
class Outputs:
|
|
31
|
-
proto_outputs:
|
|
37
|
+
proto_outputs: common_pb2.Outputs
|
|
32
38
|
|
|
33
39
|
|
|
34
40
|
@dataclass
|
|
@@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
|
|
|
102
108
|
return NoneType in get_args(tp) # fastest check
|
|
103
109
|
|
|
104
110
|
|
|
105
|
-
async def convert_from_native_to_inputs(
|
|
111
|
+
async def convert_from_native_to_inputs(
|
|
112
|
+
interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
|
|
113
|
+
) -> Inputs:
|
|
106
114
|
kwargs = interface.convert_to_kwargs(*args, **kwargs)
|
|
107
115
|
|
|
108
116
|
missing = [key for key in interface.required_inputs() if key not in kwargs]
|
|
109
117
|
if missing:
|
|
110
118
|
raise ValueError(f"Missing required inputs: {', '.join(missing)}")
|
|
111
119
|
|
|
120
|
+
# Read custom_context from TaskContext if available (inside task execution)
|
|
121
|
+
# Otherwise use the passed parameter (for remote run initiation)
|
|
122
|
+
context_kvs = None
|
|
123
|
+
tctx = ctx()
|
|
124
|
+
if tctx and tctx.custom_context:
|
|
125
|
+
# Inside a task - read from TaskContext
|
|
126
|
+
context_to_use = tctx.custom_context
|
|
127
|
+
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
|
|
128
|
+
elif custom_context:
|
|
129
|
+
# Remote run initiation
|
|
130
|
+
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
|
|
131
|
+
|
|
112
132
|
if len(interface.inputs) == 0:
|
|
113
|
-
|
|
133
|
+
# Handle context even for empty inputs
|
|
134
|
+
return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))
|
|
114
135
|
|
|
115
136
|
# fill in defaults if missing
|
|
116
137
|
type_hints: Dict[str, type] = {}
|
|
@@ -122,13 +143,14 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
|
|
|
122
143
|
(default_value is not None and default_value is not inspect.Signature.empty)
|
|
123
144
|
or (default_value is None and is_optional_type(input_type))
|
|
124
145
|
or input_type is None
|
|
146
|
+
or input_type is type(None)
|
|
125
147
|
):
|
|
126
148
|
if default_value == NativeInterface.has_default:
|
|
127
149
|
if interface._remote_defaults is None or input_name not in interface._remote_defaults:
|
|
128
150
|
raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
|
|
129
151
|
already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
|
|
130
|
-
elif input_type is None:
|
|
131
|
-
# If the type is None, we assume it's a placeholder for no type
|
|
152
|
+
elif input_type is None or input_type is type(None):
|
|
153
|
+
# If the type is 'None' or 'class<None>', we assume it's a placeholder for no type
|
|
132
154
|
kwargs[input_name] = None
|
|
133
155
|
type_hints[input_name] = NoneType
|
|
134
156
|
else:
|
|
@@ -144,12 +166,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
|
|
|
144
166
|
for k, v in already_converted_kwargs.items():
|
|
145
167
|
copied_literals[k] = v
|
|
146
168
|
literal_map = literals_pb2.LiteralMap(literals=copied_literals)
|
|
169
|
+
|
|
147
170
|
# Make sure we the interface, not literal_map or kwargs, because those may have a different order
|
|
148
171
|
return Inputs(
|
|
149
|
-
proto_inputs=
|
|
150
|
-
literals=[
|
|
151
|
-
|
|
152
|
-
]
|
|
172
|
+
proto_inputs=common_pb2.Inputs(
|
|
173
|
+
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
|
|
174
|
+
context=context_kvs,
|
|
153
175
|
)
|
|
154
176
|
)
|
|
155
177
|
|
|
@@ -191,11 +213,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
|
|
|
191
213
|
for (output_name, python_type), v in zip(interface.outputs.items(), o):
|
|
192
214
|
try:
|
|
193
215
|
lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
|
|
194
|
-
named.append(
|
|
216
|
+
named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
|
|
195
217
|
except TypeTransformerFailedError as e:
|
|
196
218
|
raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
|
|
197
219
|
|
|
198
|
-
return Outputs(proto_outputs=
|
|
220
|
+
return Outputs(proto_outputs=common_pb2.Outputs(literals=named))
|
|
199
221
|
|
|
200
222
|
|
|
201
223
|
async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
|
|
@@ -222,7 +244,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
|
|
|
222
244
|
if isinstance(err, Error):
|
|
223
245
|
err = err.err
|
|
224
246
|
|
|
225
|
-
user_code,
|
|
247
|
+
user_code, _server_code = _clean_error_code(err.code)
|
|
226
248
|
match err.kind:
|
|
227
249
|
case execution_pb2.ExecutionError.UNKNOWN:
|
|
228
250
|
return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
|
|
@@ -351,7 +373,7 @@ def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
|
|
|
351
373
|
return literal.SerializeToString(deterministic=True)
|
|
352
374
|
|
|
353
375
|
|
|
354
|
-
def generate_inputs_hash_for_named_literals(inputs: list[
|
|
376
|
+
def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
|
|
355
377
|
"""
|
|
356
378
|
Generate a hash for the inputs using the new literal representation approach that respects
|
|
357
379
|
hash values already present in literals. This is used to uniquely identify the inputs for a task
|
|
@@ -375,7 +397,7 @@ def generate_inputs_hash_for_named_literals(inputs: list[run_definition_pb2.Name
|
|
|
375
397
|
return hash_data(combined_bytes)
|
|
376
398
|
|
|
377
399
|
|
|
378
|
-
def generate_inputs_hash_from_proto(inputs:
|
|
400
|
+
def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
|
|
379
401
|
"""
|
|
380
402
|
Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
|
|
381
403
|
:param inputs: The inputs to hash.
|
|
@@ -404,7 +426,7 @@ def generate_cache_key_hash(
|
|
|
404
426
|
task_interface: interface_pb2.TypedInterface,
|
|
405
427
|
cache_version: str,
|
|
406
428
|
ignored_input_vars: List[str],
|
|
407
|
-
proto_inputs:
|
|
429
|
+
proto_inputs: common_pb2.Inputs,
|
|
408
430
|
) -> str:
|
|
409
431
|
"""
|
|
410
432
|
Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
|
|
@@ -420,7 +442,7 @@ def generate_cache_key_hash(
|
|
|
420
442
|
"""
|
|
421
443
|
if ignored_input_vars:
|
|
422
444
|
filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
|
|
423
|
-
final =
|
|
445
|
+
final = common_pb2.Inputs(literals=filtered)
|
|
424
446
|
final_inputs = generate_inputs_hash_from_proto(final)
|
|
425
447
|
else:
|
|
426
448
|
final_inputs = inputs_hash
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
|
+
import os
|
|
3
|
+
import traceback
|
|
2
4
|
from typing import List, Optional, Tuple, Type
|
|
3
5
|
|
|
4
6
|
import flyte.errors
|
|
@@ -10,6 +12,7 @@ from flyte._logging import log, logger
|
|
|
10
12
|
from flyte._task import TaskTemplate
|
|
11
13
|
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
12
14
|
|
|
15
|
+
from ..._utils import adjust_sys_path
|
|
13
16
|
from .convert import Error, Inputs, Outputs
|
|
14
17
|
from .taskrunner import (
|
|
15
18
|
convert_and_run,
|
|
@@ -72,7 +75,26 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
|
|
|
72
75
|
"""
|
|
73
76
|
resolver_class = load_class(resolver)
|
|
74
77
|
resolver_instance = resolver_class()
|
|
75
|
-
|
|
78
|
+
try:
|
|
79
|
+
return resolver_instance.load_task(resolver_args)
|
|
80
|
+
except ModuleNotFoundError as e:
|
|
81
|
+
cwd = os.getcwd()
|
|
82
|
+
files = []
|
|
83
|
+
try:
|
|
84
|
+
for root, dirs, filenames in os.walk(cwd):
|
|
85
|
+
for name in dirs + filenames:
|
|
86
|
+
rel_path = os.path.relpath(os.path.join(root, name), cwd)
|
|
87
|
+
files.append(rel_path)
|
|
88
|
+
except Exception as list_err:
|
|
89
|
+
files = [f"(Failed to list directory: {list_err})"]
|
|
90
|
+
|
|
91
|
+
msg = (
|
|
92
|
+
"\n\nFull traceback:\n" + "".join(traceback.format_exc()) + f"\n[ImportError Diagnostics]\n"
|
|
93
|
+
f"Module '{e.name}' not found in either the Python virtual environment or the current working directory.\n"
|
|
94
|
+
f"Current working directory: {cwd}\n"
|
|
95
|
+
f"Files found under current directory:\n" + "\n".join(f" - {f}" for f in files)
|
|
96
|
+
)
|
|
97
|
+
raise ModuleNotFoundError(msg) from e
|
|
76
98
|
|
|
77
99
|
|
|
78
100
|
def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
|
|
@@ -100,6 +122,7 @@ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
|
|
|
100
122
|
:param code_bundle: The code bundle to download.
|
|
101
123
|
:return: The code bundle with the downloaded path.
|
|
102
124
|
"""
|
|
125
|
+
adjust_sys_path()
|
|
103
126
|
logger.debug(f"Downloading {code_bundle}")
|
|
104
127
|
downloaded_path = await download_bundle(code_bundle)
|
|
105
128
|
return code_bundle.with_downloaded_path(downloaded_path)
|
flyte/_internal/runtime/io.py
CHANGED
|
@@ -5,10 +5,12 @@ It uses the storage module to handle the actual uploading and downloading of fil
|
|
|
5
5
|
TODO: Convert to use streaming apis
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from flyteidl.core import errors_pb2
|
|
8
|
+
from flyteidl.core import errors_pb2
|
|
9
|
+
from flyteidl2.core import execution_pb2
|
|
10
|
+
from flyteidl2.task import common_pb2
|
|
9
11
|
|
|
10
12
|
import flyte.storage as storage
|
|
11
|
-
from flyte.
|
|
13
|
+
from flyte.models import PathRewrite
|
|
12
14
|
|
|
13
15
|
from .convert import Inputs, Outputs, _clean_error_code
|
|
14
16
|
|
|
@@ -69,7 +71,7 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
|
|
|
69
71
|
await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)
|
|
70
72
|
|
|
71
73
|
|
|
72
|
-
async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
|
|
74
|
+
async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
|
|
73
75
|
"""
|
|
74
76
|
:param err: execution_pb2.ExecutionError
|
|
75
77
|
:param output_prefix: The output prefix of the remote uri.
|
|
@@ -86,17 +88,18 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
|
|
|
86
88
|
)
|
|
87
89
|
)
|
|
88
90
|
error_uri = error_path(output_prefix)
|
|
89
|
-
await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
|
|
91
|
+
return await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
|
|
90
92
|
|
|
91
93
|
|
|
92
94
|
# ------------------------------- DOWNLOAD Methods ------------------------------- #
|
|
93
|
-
async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
|
|
95
|
+
async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathRewrite | None = None) -> Inputs:
|
|
94
96
|
"""
|
|
95
97
|
:param path: Input file to be downloaded
|
|
96
98
|
:param max_bytes: Maximum number of bytes to read from the input file. Default is -1, which means no limit.
|
|
99
|
+
:param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
|
|
97
100
|
:return: Inputs object
|
|
98
101
|
"""
|
|
99
|
-
lm =
|
|
102
|
+
lm = common_pb2.Inputs()
|
|
100
103
|
|
|
101
104
|
if max_bytes == -1:
|
|
102
105
|
proto_str = b"".join([c async for c in storage.get_stream(path=path)])
|
|
@@ -115,6 +118,16 @@ async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
|
|
|
115
118
|
proto_str = b"".join(proto_bytes)
|
|
116
119
|
|
|
117
120
|
lm.ParseFromString(proto_str)
|
|
121
|
+
|
|
122
|
+
if path_rewrite_config is not None:
|
|
123
|
+
for inp in lm.literals:
|
|
124
|
+
if inp.value.HasField("scalar") and inp.value.scalar.HasField("blob"):
|
|
125
|
+
scalar_blob = inp.value.scalar.blob
|
|
126
|
+
if scalar_blob.uri.startswith(path_rewrite_config.old_prefix):
|
|
127
|
+
scalar_blob.uri = scalar_blob.uri.replace(
|
|
128
|
+
path_rewrite_config.old_prefix, path_rewrite_config.new_prefix, 1
|
|
129
|
+
)
|
|
130
|
+
|
|
118
131
|
return Inputs(proto_inputs=lm)
|
|
119
132
|
|
|
120
133
|
|
|
@@ -125,7 +138,7 @@ async def load_outputs(path: str, max_bytes: int = -1) -> Outputs:
|
|
|
125
138
|
If -1, reads the entire file.
|
|
126
139
|
:return: Outputs object
|
|
127
140
|
"""
|
|
128
|
-
lm =
|
|
141
|
+
lm = common_pb2.Outputs()
|
|
129
142
|
|
|
130
143
|
if max_bytes == -1:
|
|
131
144
|
proto_str = b"".join([c async for c in storage.get_stream(path=path)])
|
|
@@ -157,7 +170,7 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
|
|
|
157
170
|
err.ParseFromString(proto_str)
|
|
158
171
|
|
|
159
172
|
if err.error is not None:
|
|
160
|
-
user_code,
|
|
173
|
+
user_code, _server_code = _clean_error_code(err.error.code)
|
|
161
174
|
return execution_pb2.ExecutionError(
|
|
162
175
|
code=user_code,
|
|
163
176
|
message=err.error.message,
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from typing import List, Optional, Tuple
|
|
1
|
+
from typing import Dict, List, Optional, Tuple
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from flyteidl2.core import tasks_pb2
|
|
4
4
|
|
|
5
|
-
from flyte._resources import CPUBaseType, Resources
|
|
5
|
+
from flyte._resources import CPUBaseType, DeviceClass, Resources
|
|
6
6
|
|
|
7
7
|
ACCELERATOR_DEVICE_MAP = {
|
|
8
8
|
"A100": "nvidia-tesla-a100",
|
|
@@ -24,6 +24,14 @@ ACCELERATOR_DEVICE_MAP = {
|
|
|
24
24
|
"V6E": "tpu-v6e-slice",
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
+
_DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
|
|
28
|
+
"GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
|
|
29
|
+
"TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
|
|
30
|
+
"NEURON": tasks_pb2.GPUAccelerator.AMAZON_NEURON,
|
|
31
|
+
"AMD_GPU": tasks_pb2.GPUAccelerator.AMD_GPU,
|
|
32
|
+
"HABANA_GAUDI": tasks_pb2.GPUAccelerator.HABANA_GAUDI,
|
|
33
|
+
}
|
|
34
|
+
|
|
27
35
|
|
|
28
36
|
def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
|
|
29
37
|
return tasks_pb2.Resources.ResourceEntry(
|
|
@@ -54,11 +62,17 @@ def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2
|
|
|
54
62
|
device = resources.get_device()
|
|
55
63
|
if device is None:
|
|
56
64
|
return None
|
|
57
|
-
|
|
58
|
-
|
|
65
|
+
|
|
66
|
+
device_class = _DeviceClassToProto.get(device.device_class, tasks_pb2.GPUAccelerator.NVIDIA_GPU)
|
|
67
|
+
if device.device is None:
|
|
68
|
+
raise RuntimeError("Device type must be specified for GPU string.")
|
|
69
|
+
else:
|
|
70
|
+
device_type = device.device
|
|
71
|
+
device_type = ACCELERATOR_DEVICE_MAP.get(device_type, device_type)
|
|
59
72
|
return tasks_pb2.GPUAccelerator(
|
|
60
|
-
device=
|
|
73
|
+
device=device_type,
|
|
61
74
|
partition_size=device.partition if device.partition else None,
|
|
75
|
+
device_class=device_class,
|
|
62
76
|
)
|
|
63
77
|
|
|
64
78
|
|
flyte/_internal/runtime/reuse.py
CHANGED
flyte/_internal/runtime/rusty.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import sys
|
|
3
2
|
import time
|
|
4
3
|
from typing import Any, List, Tuple
|
|
5
4
|
|
|
@@ -11,7 +10,8 @@ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_t
|
|
|
11
10
|
from flyte._internal.runtime.taskrunner import extract_download_run_upload
|
|
12
11
|
from flyte._logging import logger
|
|
13
12
|
from flyte._task import TaskTemplate
|
|
14
|
-
from flyte.
|
|
13
|
+
from flyte._utils import adjust_sys_path
|
|
14
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
|
|
@@ -23,7 +23,7 @@ async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
|
|
|
23
23
|
:return: The CodeBundle object.
|
|
24
24
|
"""
|
|
25
25
|
logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
|
|
26
|
-
|
|
26
|
+
adjust_sys_path()
|
|
27
27
|
|
|
28
28
|
code_bundle = CodeBundle(
|
|
29
29
|
tgz=tgz,
|
|
@@ -42,7 +42,7 @@ async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[C
|
|
|
42
42
|
:return: The CodeBundle object.
|
|
43
43
|
"""
|
|
44
44
|
logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
|
|
45
|
-
|
|
45
|
+
adjust_sys_path()
|
|
46
46
|
|
|
47
47
|
code_bundle = CodeBundle(
|
|
48
48
|
pkl=pkl,
|
|
@@ -115,6 +115,7 @@ async def run_task(
|
|
|
115
115
|
prev_checkpoint: str | None = None,
|
|
116
116
|
code_bundle: CodeBundle | None = None,
|
|
117
117
|
input_path: str | None = None,
|
|
118
|
+
path_rewrite_cfg: str | None = None,
|
|
118
119
|
):
|
|
119
120
|
"""
|
|
120
121
|
Runs the task with the provided parameters.
|
|
@@ -134,6 +135,7 @@ async def run_task(
|
|
|
134
135
|
:param controller: The controller to use for the task.
|
|
135
136
|
:param code_bundle: Optional code bundle for the task.
|
|
136
137
|
:param input_path: Optional input path for the task.
|
|
138
|
+
:param path_rewrite_cfg: Optional path rewrite configuration.
|
|
137
139
|
:return: The loaded task template.
|
|
138
140
|
"""
|
|
139
141
|
start_time = time.time()
|
|
@@ -144,6 +146,19 @@ async def run_task(
|
|
|
144
146
|
f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
|
|
145
147
|
)
|
|
146
148
|
|
|
149
|
+
path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
|
|
150
|
+
if path_rewrite:
|
|
151
|
+
import flyte.storage as storage
|
|
152
|
+
|
|
153
|
+
if not await storage.exists(path_rewrite.new_prefix):
|
|
154
|
+
logger.error(
|
|
155
|
+
f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
|
|
156
|
+
f"not found, reverting to original path {path_rewrite.old_prefix}"
|
|
157
|
+
)
|
|
158
|
+
path_rewrite = None
|
|
159
|
+
else:
|
|
160
|
+
logger.info(f"[rusty] Using path rewrite: {path_rewrite}")
|
|
161
|
+
|
|
147
162
|
try:
|
|
148
163
|
await contextual_run(
|
|
149
164
|
extract_download_run_upload,
|
|
@@ -151,7 +166,7 @@ async def run_task(
|
|
|
151
166
|
action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
|
|
152
167
|
version=version,
|
|
153
168
|
controller=controller,
|
|
154
|
-
raw_data_path=RawDataPath(path=raw_data_path),
|
|
169
|
+
raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
|
|
155
170
|
output_path=output_path,
|
|
156
171
|
run_base_dir=run_base_dir,
|
|
157
172
|
checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
|
|
@@ -4,19 +4,18 @@ It includes a Resolver interface for loading tasks, and functions to load classe
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import copy
|
|
7
|
-
import sys
|
|
8
7
|
import typing
|
|
9
8
|
from datetime import timedelta
|
|
10
9
|
from typing import Optional, cast
|
|
11
10
|
|
|
12
|
-
from
|
|
11
|
+
from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
|
|
12
|
+
from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
|
|
13
13
|
from google.protobuf import duration_pb2, wrappers_pb2
|
|
14
14
|
|
|
15
15
|
import flyte.errors
|
|
16
16
|
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
17
17
|
from flyte._logging import logger
|
|
18
18
|
from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
|
|
19
|
-
from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
|
|
20
19
|
from flyte._secret import SecretRequest, secrets_from_request
|
|
21
20
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
22
21
|
from flyte.models import CodeBundle, SerializationContext
|
|
@@ -120,7 +119,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
120
119
|
version=serialize_context.version,
|
|
121
120
|
)
|
|
122
121
|
|
|
123
|
-
# TODO Add support for
|
|
122
|
+
# TODO Add support for extra_config, custom
|
|
124
123
|
extra_config: typing.Dict[str, str] = {}
|
|
125
124
|
|
|
126
125
|
if task.pod_template and not isinstance(task.pod_template, str):
|
|
@@ -133,7 +132,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
133
132
|
|
|
134
133
|
custom = task.custom_config(serialize_context)
|
|
135
134
|
|
|
136
|
-
sql =
|
|
135
|
+
sql = task.sql(serialize_context)
|
|
137
136
|
|
|
138
137
|
# -------------- CACHE HANDLING ----------------------
|
|
139
138
|
task_cache = cache_from_request(task.cache)
|
|
@@ -171,8 +170,9 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
171
170
|
retries=get_proto_retry_strategy(task.retries),
|
|
172
171
|
timeout=get_proto_timeout(task.timeout),
|
|
173
172
|
pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
|
|
174
|
-
interruptible=task.
|
|
173
|
+
interruptible=task.interruptible,
|
|
175
174
|
generates_deck=wrappers_pb2.BoolValue(value=task.report),
|
|
175
|
+
debuggable=task.debuggable,
|
|
176
176
|
),
|
|
177
177
|
interface=transform_native_to_typed_interface(task.native_interface),
|
|
178
178
|
custom=custom if len(custom) > 0 else None,
|
|
@@ -209,34 +209,40 @@ def _get_urun_container(
|
|
|
209
209
|
else None
|
|
210
210
|
)
|
|
211
211
|
resources = get_proto_resources(task_template.resources)
|
|
212
|
-
|
|
212
|
+
|
|
213
213
|
if isinstance(task_template.image, str):
|
|
214
214
|
raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")
|
|
215
|
-
|
|
215
|
+
|
|
216
|
+
env_name = task_template.parent_env_name
|
|
217
|
+
if env_name is None:
|
|
218
|
+
raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")
|
|
219
|
+
|
|
216
220
|
if not serialize_context.image_cache:
|
|
217
221
|
# This computes the image uri, computing hashes as necessary so can fail if done remotely.
|
|
218
222
|
img_uri = task_template.image.uri
|
|
219
|
-
elif serialize_context.image_cache and
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
223
|
+
elif serialize_context.image_cache and env_name not in serialize_context.image_cache.image_lookup:
|
|
224
|
+
raise flyte.errors.RuntimeUserError(
|
|
225
|
+
"MissingEnvironment",
|
|
226
|
+
f"Environment '{env_name}' not found in image cache.\n\n"
|
|
227
|
+
"💡 To fix this:\n"
|
|
228
|
+
" 1. If your parent environment calls a task in another environment,"
|
|
229
|
+
" declare that dependency using 'depends_on=[...]'.\n"
|
|
230
|
+
" Example:\n"
|
|
231
|
+
" env1 = flyte.TaskEnvironment(\n"
|
|
232
|
+
" name='outer',\n"
|
|
233
|
+
" image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
|
|
234
|
+
" depends_on=[env2, env3],\n"
|
|
235
|
+
" )\n"
|
|
236
|
+
" 2. If you're using os.getenv() to set the environment name,"
|
|
237
|
+
" make sure the runtime environment has the same environment variable defined.\n"
|
|
238
|
+
" Example:\n"
|
|
239
|
+
" env = flyte.TaskEnvironment(\n"
|
|
240
|
+
' name=os.getenv("my-name"),\n'
|
|
241
|
+
' env_vars={"my-name": os.getenv("my-name")},\n'
|
|
242
|
+
" )\n",
|
|
224
243
|
)
|
|
225
244
|
else:
|
|
226
|
-
|
|
227
|
-
version_lookup = serialize_context.image_cache.image_lookup[image_id]
|
|
228
|
-
if python_version_str in version_lookup:
|
|
229
|
-
img_uri = version_lookup[python_version_str]
|
|
230
|
-
elif version_lookup:
|
|
231
|
-
# Fallback: try to get any available version
|
|
232
|
-
fallback_py_version, img_uri = next(iter(version_lookup.items()))
|
|
233
|
-
logger.warning(
|
|
234
|
-
f"Image {task_template.image} for python version {python_version_str} "
|
|
235
|
-
f"not found in the image cache: {serialize_context.image_cache.image_lookup}.\n"
|
|
236
|
-
f"Fall back using image {img_uri} for python version {fallback_py_version} ."
|
|
237
|
-
)
|
|
238
|
-
else:
|
|
239
|
-
img_uri = task_template.image.uri
|
|
245
|
+
img_uri = serialize_context.image_cache.image_lookup[env_name]
|
|
240
246
|
|
|
241
247
|
return tasks_pb2.Container(
|
|
242
248
|
image=img_uri,
|
|
@@ -129,6 +129,14 @@ async def convert_and_run(
|
|
|
129
129
|
in a context tree.
|
|
130
130
|
"""
|
|
131
131
|
ctx = internal_ctx()
|
|
132
|
+
|
|
133
|
+
# Load inputs first to get context
|
|
134
|
+
if input_path:
|
|
135
|
+
inputs = await load_inputs(input_path, path_rewrite_config=raw_data_path.path_rewrite)
|
|
136
|
+
|
|
137
|
+
# Extract context from inputs
|
|
138
|
+
custom_context = inputs.context if inputs else {}
|
|
139
|
+
|
|
132
140
|
tctx = TaskContext(
|
|
133
141
|
action=action,
|
|
134
142
|
checkpoints=checkpoints,
|
|
@@ -142,9 +150,10 @@ async def convert_and_run(
|
|
|
142
150
|
report=flyte.report.Report(name=action.name),
|
|
143
151
|
mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
|
|
144
152
|
interactive_mode=interactive_mode,
|
|
153
|
+
custom_context=custom_context,
|
|
145
154
|
)
|
|
155
|
+
|
|
146
156
|
with ctx.replace_task_context(tctx):
|
|
147
|
-
inputs = await load_inputs(input_path) if input_path else inputs
|
|
148
157
|
inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
|
|
149
158
|
out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
|
|
150
159
|
if err is not None:
|