flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +18 -2
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +62 -8
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +12 -4
- flyte/_code_bundle/_packaging.py +13 -9
- flyte/_code_bundle/_utils.py +18 -10
- flyte/_code_bundle/bundle.py +17 -9
- flyte/_constants.py +1 -0
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +235 -61
- flyte/_environment.py +20 -6
- flyte/_excepthook.py +1 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +178 -81
- flyte/_initialize.py +132 -51
- flyte/_interface.py +39 -2
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +70 -29
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_action.py +14 -16
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +68 -70
- flyte/_internal/controllers/remote/_core.py +127 -99
- flyte/_internal/controllers/remote/_informer.py +19 -10
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +181 -69
- flyte/_internal/imagebuild/image_builder.py +0 -5
- flyte/_internal/imagebuild/remote_builder.py +155 -64
- flyte/_internal/imagebuild/utils.py +51 -2
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +110 -21
- flyte/_internal/runtime/entrypoints.py +27 -1
- flyte/_internal/runtime/io.py +21 -8
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/rusty.py +20 -5
- flyte/_internal/runtime/task_serde.py +34 -19
- flyte/_internal/runtime/taskrunner.py +22 -4
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +201 -39
- flyte/_map.py +111 -14
- flyte/_module.py +70 -0
- flyte/_pod.py +4 -3
- flyte/_resources.py +213 -31
- flyte/_run.py +110 -39
- flyte/_task.py +75 -16
- flyte/_task_environment.py +105 -29
- flyte/_task_plugins.py +4 -2
- flyte/_trace.py +5 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/coro_management.py +2 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +17 -2
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +3 -6
- flyte/cli/_common.py +78 -7
- flyte/cli/_create.py +182 -4
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +63 -16
- flyte/cli/_get.py +79 -34
- flyte/cli/_params.py +26 -10
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +151 -26
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +30 -4
- flyte/config/_config.py +10 -6
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +29 -8
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +22 -2
- flyte/extend.py +8 -1
- flyte/extras/_container.py +6 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +2 -0
- flyte/io/_dataframe/__init__.py +2 -0
- flyte/io/_dataframe/basic_dfs.py +17 -8
- flyte/io/_dataframe/dataframe.py +98 -132
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +582 -139
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +74 -15
- flyte/remote/__init__.py +6 -1
- flyte/remote/_action.py +34 -26
- flyte/remote/_client/_protocols.py +39 -4
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
- flyte/remote/_client/auth/_channel.py +10 -6
- flyte/remote/_client/controlplane.py +17 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_data.py +6 -6
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +64 -8
- flyte/remote/_secret.py +26 -17
- flyte/remote/_task.py +75 -33
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/_report.py +1 -1
- flyte/storage/__init__.py +6 -1
- flyte/storage/_config.py +5 -1
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +200 -103
- flyte/types/__init__.py +16 -0
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +35 -8
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +40 -70
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b30.data/scripts/debug.py +38 -0
- {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
- flyte-2.0.0b30.dist-info/RECORD +192 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -93
- flyte/_protos/common/identifier_pb2.pyi +0 -110
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -59
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -27
- flyte/_protos/workflow/common_pb2.pyi +0 -14
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -109
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -121
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -137
- flyte/_protos/workflow/run_service_pb2.pyi +0 -185
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -79
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -60
- flyte/_protos/workflow/task_service_pb2.pyi +0 -59
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte-2.0.0b13.dist-info/RECORD +0 -239
- /flyte/{_protos → _debug}/__init__.py +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
|
@@ -8,27 +8,33 @@ from dataclasses import dataclass
|
|
|
8
8
|
from types import NoneType
|
|
9
9
|
from typing import Any, Dict, List, Tuple, Union, get_args
|
|
10
10
|
|
|
11
|
-
from
|
|
11
|
+
from flyteidl2.core import execution_pb2, interface_pb2, literals_pb2
|
|
12
|
+
from flyteidl2.task import common_pb2, task_definition_pb2
|
|
12
13
|
|
|
13
14
|
import flyte.errors
|
|
14
15
|
import flyte.storage as storage
|
|
15
|
-
from flyte.
|
|
16
|
+
from flyte._context import ctx
|
|
16
17
|
from flyte.models import ActionID, NativeInterface, TaskContext
|
|
17
18
|
from flyte.types import TypeEngine, TypeTransformerFailedError
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
@dataclass(frozen=True)
|
|
21
22
|
class Inputs:
|
|
22
|
-
proto_inputs:
|
|
23
|
+
proto_inputs: common_pb2.Inputs
|
|
23
24
|
|
|
24
25
|
@classmethod
|
|
25
26
|
def empty(cls) -> "Inputs":
|
|
26
|
-
return cls(proto_inputs=
|
|
27
|
+
return cls(proto_inputs=common_pb2.Inputs())
|
|
28
|
+
|
|
29
|
+
@property
|
|
30
|
+
def context(self) -> Dict[str, str]:
|
|
31
|
+
"""Get the context as a dictionary."""
|
|
32
|
+
return {kv.key: kv.value for kv in self.proto_inputs.context}
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
@dataclass(frozen=True)
|
|
30
36
|
class Outputs:
|
|
31
|
-
proto_outputs:
|
|
37
|
+
proto_outputs: common_pb2.Outputs
|
|
32
38
|
|
|
33
39
|
|
|
34
40
|
@dataclass
|
|
@@ -102,15 +108,30 @@ def is_optional_type(tp) -> bool:
|
|
|
102
108
|
return NoneType in get_args(tp) # fastest check
|
|
103
109
|
|
|
104
110
|
|
|
105
|
-
async def convert_from_native_to_inputs(
|
|
111
|
+
async def convert_from_native_to_inputs(
|
|
112
|
+
interface: NativeInterface, *args, custom_context: Dict[str, str] | None = None, **kwargs
|
|
113
|
+
) -> Inputs:
|
|
106
114
|
kwargs = interface.convert_to_kwargs(*args, **kwargs)
|
|
107
115
|
|
|
108
116
|
missing = [key for key in interface.required_inputs() if key not in kwargs]
|
|
109
117
|
if missing:
|
|
110
118
|
raise ValueError(f"Missing required inputs: {', '.join(missing)}")
|
|
111
119
|
|
|
120
|
+
# Read custom_context from TaskContext if available (inside task execution)
|
|
121
|
+
# Otherwise use the passed parameter (for remote run initiation)
|
|
122
|
+
context_kvs = None
|
|
123
|
+
tctx = ctx()
|
|
124
|
+
if tctx and tctx.custom_context:
|
|
125
|
+
# Inside a task - read from TaskContext
|
|
126
|
+
context_to_use = tctx.custom_context
|
|
127
|
+
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in context_to_use.items()]
|
|
128
|
+
elif custom_context:
|
|
129
|
+
# Remote run initiation
|
|
130
|
+
context_kvs = [literals_pb2.KeyValuePair(key=k, value=v) for k, v in custom_context.items()]
|
|
131
|
+
|
|
112
132
|
if len(interface.inputs) == 0:
|
|
113
|
-
|
|
133
|
+
# Handle context even for empty inputs
|
|
134
|
+
return Inputs(proto_inputs=common_pb2.Inputs(context=context_kvs))
|
|
114
135
|
|
|
115
136
|
# fill in defaults if missing
|
|
116
137
|
type_hints: Dict[str, type] = {}
|
|
@@ -122,13 +143,14 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
|
|
|
122
143
|
(default_value is not None and default_value is not inspect.Signature.empty)
|
|
123
144
|
or (default_value is None and is_optional_type(input_type))
|
|
124
145
|
or input_type is None
|
|
146
|
+
or input_type is type(None)
|
|
125
147
|
):
|
|
126
148
|
if default_value == NativeInterface.has_default:
|
|
127
149
|
if interface._remote_defaults is None or input_name not in interface._remote_defaults:
|
|
128
150
|
raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
|
|
129
151
|
already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
|
|
130
|
-
elif input_type is None:
|
|
131
|
-
# If the type is None, we assume it's a placeholder for no type
|
|
152
|
+
elif input_type is None or input_type is type(None):
|
|
153
|
+
# If the type is 'None' or 'class<None>', we assume it's a placeholder for no type
|
|
132
154
|
kwargs[input_name] = None
|
|
133
155
|
type_hints[input_name] = NoneType
|
|
134
156
|
else:
|
|
@@ -144,12 +166,12 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
|
|
|
144
166
|
for k, v in already_converted_kwargs.items():
|
|
145
167
|
copied_literals[k] = v
|
|
146
168
|
literal_map = literals_pb2.LiteralMap(literals=copied_literals)
|
|
169
|
+
|
|
147
170
|
# Make sure we use the interface, not literal_map or kwargs, because those may have a different order
|
|
148
171
|
return Inputs(
|
|
149
|
-
proto_inputs=
|
|
150
|
-
literals=[
|
|
151
|
-
|
|
152
|
-
]
|
|
172
|
+
proto_inputs=common_pb2.Inputs(
|
|
173
|
+
literals=[common_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()],
|
|
174
|
+
context=context_kvs,
|
|
153
175
|
)
|
|
154
176
|
)
|
|
155
177
|
|
|
@@ -191,11 +213,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
|
|
|
191
213
|
for (output_name, python_type), v in zip(interface.outputs.items(), o):
|
|
192
214
|
try:
|
|
193
215
|
lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
|
|
194
|
-
named.append(
|
|
216
|
+
named.append(common_pb2.NamedLiteral(name=output_name, value=lit))
|
|
195
217
|
except TypeTransformerFailedError as e:
|
|
196
218
|
raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
|
|
197
219
|
|
|
198
|
-
return Outputs(proto_outputs=
|
|
220
|
+
return Outputs(proto_outputs=common_pb2.Outputs(literals=named))
|
|
199
221
|
|
|
200
222
|
|
|
201
223
|
async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs) -> Union[Any, Tuple[Any, ...]]:
|
|
@@ -222,7 +244,7 @@ def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Erro
|
|
|
222
244
|
if isinstance(err, Error):
|
|
223
245
|
err = err.err
|
|
224
246
|
|
|
225
|
-
user_code,
|
|
247
|
+
user_code, _server_code = _clean_error_code(err.code)
|
|
226
248
|
match err.kind:
|
|
227
249
|
case execution_pb2.ExecutionError.UNKNOWN:
|
|
228
250
|
return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
|
|
@@ -308,15 +330,82 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
|
|
|
308
330
|
return hash_data(serialized_inputs)
|
|
309
331
|
|
|
310
332
|
|
|
311
|
-
def
|
|
333
|
+
def generate_inputs_repr_for_literal(literal: literals_pb2.Literal) -> bytes:
|
|
334
|
+
"""
|
|
335
|
+
Generate a byte representation for a single literal that is meant to be hashed as part of the cache key
|
|
336
|
+
computation for an Action. This function should just serialize the literal deterministically, but will
|
|
337
|
+
use an existing hash value if present in the Literal. This is trivial, except we need to handle nested literals
|
|
338
|
+
(inside collections and maps), that may have the hash property set.
|
|
339
|
+
|
|
340
|
+
:param literal: The literal to get a hashable representation for.
|
|
341
|
+
:return: byte representation of the literal that can be fed into a hash function.
|
|
342
|
+
"""
|
|
343
|
+
# If the literal has a hash value, use that instead of serializing the full literal
|
|
344
|
+
if literal.hash:
|
|
345
|
+
return literal.hash.encode("utf-8")
|
|
346
|
+
|
|
347
|
+
if literal.HasField("collection"):
|
|
348
|
+
buf = bytearray()
|
|
349
|
+
for nested_literal in literal.collection.literals:
|
|
350
|
+
if nested_literal.hash:
|
|
351
|
+
buf += nested_literal.hash.encode("utf-8")
|
|
352
|
+
else:
|
|
353
|
+
buf += generate_inputs_repr_for_literal(nested_literal)
|
|
354
|
+
|
|
355
|
+
b = bytes(buf)
|
|
356
|
+
return b
|
|
357
|
+
|
|
358
|
+
elif literal.HasField("map"):
|
|
359
|
+
buf = bytearray()
|
|
360
|
+
# Sort keys to ensure deterministic ordering
|
|
361
|
+
for key in sorted(literal.map.literals.keys()):
|
|
362
|
+
nested_literal = literal.map.literals[key]
|
|
363
|
+
buf += key.encode("utf-8")
|
|
364
|
+
if nested_literal.hash:
|
|
365
|
+
buf += nested_literal.hash.encode("utf-8")
|
|
366
|
+
else:
|
|
367
|
+
buf += generate_inputs_repr_for_literal(nested_literal)
|
|
368
|
+
|
|
369
|
+
b = bytes(buf)
|
|
370
|
+
return b
|
|
371
|
+
|
|
372
|
+
# For all other cases (scalars, etc.), just serialize the literal normally
|
|
373
|
+
return literal.SerializeToString(deterministic=True)
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def generate_inputs_hash_for_named_literals(inputs: list[common_pb2.NamedLiteral]) -> str:
|
|
377
|
+
"""
|
|
378
|
+
Generate a hash for the inputs using the new literal representation approach that respects
|
|
379
|
+
hash values already present in literals. This is used to uniquely identify the inputs for a task
|
|
380
|
+
when some literals may have precomputed hash values.
|
|
381
|
+
|
|
382
|
+
:param inputs: List of NamedLiteral inputs to hash.
|
|
383
|
+
:return: A base64-encoded string representation of the hash.
|
|
384
|
+
"""
|
|
385
|
+
if not inputs:
|
|
386
|
+
return ""
|
|
387
|
+
|
|
388
|
+
# Build the byte representation by concatenating each literal's representation
|
|
389
|
+
combined_bytes = b""
|
|
390
|
+
for named_literal in inputs:
|
|
391
|
+
# Add the name to ensure order matters
|
|
392
|
+
name_bytes = named_literal.name.encode("utf-8")
|
|
393
|
+
literal_bytes = generate_inputs_repr_for_literal(named_literal.value)
|
|
394
|
+
# Combine name and literal bytes with a separator to avoid collisions
|
|
395
|
+
combined_bytes += name_bytes + b":" + literal_bytes + b";"
|
|
396
|
+
|
|
397
|
+
return hash_data(combined_bytes)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def generate_inputs_hash_from_proto(inputs: common_pb2.Inputs) -> str:
|
|
312
401
|
"""
|
|
313
402
|
Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
|
|
314
403
|
:param inputs: The inputs to hash.
|
|
315
404
|
:return: A hexadecimal string representation of the hash.
|
|
316
405
|
"""
|
|
317
|
-
if not inputs:
|
|
406
|
+
if not inputs or not inputs.literals:
|
|
318
407
|
return ""
|
|
319
|
-
return
|
|
408
|
+
return generate_inputs_hash_for_named_literals(list(inputs.literals))
|
|
320
409
|
|
|
321
410
|
|
|
322
411
|
def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
|
|
@@ -337,7 +426,7 @@ def generate_cache_key_hash(
|
|
|
337
426
|
task_interface: interface_pb2.TypedInterface,
|
|
338
427
|
cache_version: str,
|
|
339
428
|
ignored_input_vars: List[str],
|
|
340
|
-
proto_inputs:
|
|
429
|
+
proto_inputs: common_pb2.Inputs,
|
|
341
430
|
) -> str:
|
|
342
431
|
"""
|
|
343
432
|
Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
|
|
@@ -353,7 +442,7 @@ def generate_cache_key_hash(
|
|
|
353
442
|
"""
|
|
354
443
|
if ignored_input_vars:
|
|
355
444
|
filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
|
|
356
|
-
final =
|
|
445
|
+
final = common_pb2.Inputs(literals=filtered)
|
|
357
446
|
final_inputs = generate_inputs_hash_from_proto(final)
|
|
358
447
|
else:
|
|
359
448
|
final_inputs = inputs_hash
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
|
+
import os
|
|
3
|
+
import traceback
|
|
2
4
|
from typing import List, Optional, Tuple, Type
|
|
3
5
|
|
|
4
6
|
import flyte.errors
|
|
@@ -10,6 +12,7 @@ from flyte._logging import log, logger
|
|
|
10
12
|
from flyte._task import TaskTemplate
|
|
11
13
|
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
12
14
|
|
|
15
|
+
from ..._utils import adjust_sys_path
|
|
13
16
|
from .convert import Error, Inputs, Outputs
|
|
14
17
|
from .taskrunner import (
|
|
15
18
|
convert_and_run,
|
|
@@ -72,7 +75,26 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
|
|
|
72
75
|
"""
|
|
73
76
|
resolver_class = load_class(resolver)
|
|
74
77
|
resolver_instance = resolver_class()
|
|
75
|
-
|
|
78
|
+
try:
|
|
79
|
+
return resolver_instance.load_task(resolver_args)
|
|
80
|
+
except ModuleNotFoundError as e:
|
|
81
|
+
cwd = os.getcwd()
|
|
82
|
+
files = []
|
|
83
|
+
try:
|
|
84
|
+
for root, dirs, filenames in os.walk(cwd):
|
|
85
|
+
for name in dirs + filenames:
|
|
86
|
+
rel_path = os.path.relpath(os.path.join(root, name), cwd)
|
|
87
|
+
files.append(rel_path)
|
|
88
|
+
except Exception as list_err:
|
|
89
|
+
files = [f"(Failed to list directory: {list_err})"]
|
|
90
|
+
|
|
91
|
+
msg = (
|
|
92
|
+
"\n\nFull traceback:\n" + "".join(traceback.format_exc()) + f"\n[ImportError Diagnostics]\n"
|
|
93
|
+
f"Module '{e.name}' not found in either the Python virtual environment or the current working directory.\n"
|
|
94
|
+
f"Current working directory: {cwd}\n"
|
|
95
|
+
f"Files found under current directory:\n" + "\n".join(f" - {f}" for f in files)
|
|
96
|
+
)
|
|
97
|
+
raise ModuleNotFoundError(msg) from e
|
|
76
98
|
|
|
77
99
|
|
|
78
100
|
def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
|
|
@@ -100,6 +122,7 @@ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
|
|
|
100
122
|
:param code_bundle: The code bundle to download.
|
|
101
123
|
:return: The code bundle with the downloaded path.
|
|
102
124
|
"""
|
|
125
|
+
adjust_sys_path()
|
|
103
126
|
logger.debug(f"Downloading {code_bundle}")
|
|
104
127
|
downloaded_path = await download_bundle(code_bundle)
|
|
105
128
|
return code_bundle.with_downloaded_path(downloaded_path)
|
|
@@ -142,6 +165,7 @@ async def load_and_run_task(
|
|
|
142
165
|
code_bundle: CodeBundle | None = None,
|
|
143
166
|
input_path: str | None = None,
|
|
144
167
|
image_cache: ImageCache | None = None,
|
|
168
|
+
interactive_mode: bool = False,
|
|
145
169
|
):
|
|
146
170
|
"""
|
|
147
171
|
This method is invoked from the runtime/CLI and is used to run a task. This creates the context tree,
|
|
@@ -159,6 +183,7 @@ async def load_and_run_task(
|
|
|
159
183
|
:param code_bundle: The code bundle to use for the task.
|
|
160
184
|
:param input_path: The input path to use for the task.
|
|
161
185
|
:param image_cache: Mappings of Image identifiers to image URIs.
|
|
186
|
+
:param interactive_mode: Whether to run the task in interactive mode.
|
|
162
187
|
"""
|
|
163
188
|
task = await _download_and_load_task(code_bundle, resolver, resolver_args)
|
|
164
189
|
|
|
@@ -175,4 +200,5 @@ async def load_and_run_task(
|
|
|
175
200
|
code_bundle=code_bundle,
|
|
176
201
|
input_path=input_path,
|
|
177
202
|
image_cache=image_cache,
|
|
203
|
+
interactive_mode=interactive_mode,
|
|
178
204
|
)
|
flyte/_internal/runtime/io.py
CHANGED
|
@@ -5,10 +5,12 @@ It uses the storage module to handle the actual uploading and downloading of fil
|
|
|
5
5
|
TODO: Convert to use streaming apis
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from flyteidl.core import errors_pb2
|
|
8
|
+
from flyteidl.core import errors_pb2
|
|
9
|
+
from flyteidl2.core import execution_pb2
|
|
10
|
+
from flyteidl2.task import common_pb2
|
|
9
11
|
|
|
10
12
|
import flyte.storage as storage
|
|
11
|
-
from flyte.
|
|
13
|
+
from flyte.models import PathRewrite
|
|
12
14
|
|
|
13
15
|
from .convert import Inputs, Outputs, _clean_error_code
|
|
14
16
|
|
|
@@ -69,7 +71,7 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
|
|
|
69
71
|
await storage.put_stream(data_iterable=outputs.proto_outputs.SerializeToString(), to_path=output_uri)
|
|
70
72
|
|
|
71
73
|
|
|
72
|
-
async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
|
|
74
|
+
async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
|
|
73
75
|
"""
|
|
74
76
|
:param err: execution_pb2.ExecutionError
|
|
75
77
|
:param output_prefix: The output prefix of the remote uri.
|
|
@@ -86,17 +88,18 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
|
|
|
86
88
|
)
|
|
87
89
|
)
|
|
88
90
|
error_uri = error_path(output_prefix)
|
|
89
|
-
await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
|
|
91
|
+
return await storage.put_stream(data_iterable=error_document.SerializeToString(), to_path=error_uri)
|
|
90
92
|
|
|
91
93
|
|
|
92
94
|
# ------------------------------- DOWNLOAD Methods ------------------------------- #
|
|
93
|
-
async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
|
|
95
|
+
async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathRewrite | None = None) -> Inputs:
|
|
94
96
|
"""
|
|
95
97
|
:param path: Input file to be downloaded
|
|
96
98
|
:param max_bytes: Maximum number of bytes to read from the input file. Default is -1, which means no limit.
|
|
99
|
+
:param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
|
|
97
100
|
:return: Inputs object
|
|
98
101
|
"""
|
|
99
|
-
lm =
|
|
102
|
+
lm = common_pb2.Inputs()
|
|
100
103
|
|
|
101
104
|
if max_bytes == -1:
|
|
102
105
|
proto_str = b"".join([c async for c in storage.get_stream(path=path)])
|
|
@@ -115,6 +118,16 @@ async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
|
|
|
115
118
|
proto_str = b"".join(proto_bytes)
|
|
116
119
|
|
|
117
120
|
lm.ParseFromString(proto_str)
|
|
121
|
+
|
|
122
|
+
if path_rewrite_config is not None:
|
|
123
|
+
for inp in lm.literals:
|
|
124
|
+
if inp.value.HasField("scalar") and inp.value.scalar.HasField("blob"):
|
|
125
|
+
scalar_blob = inp.value.scalar.blob
|
|
126
|
+
if scalar_blob.uri.startswith(path_rewrite_config.old_prefix):
|
|
127
|
+
scalar_blob.uri = scalar_blob.uri.replace(
|
|
128
|
+
path_rewrite_config.old_prefix, path_rewrite_config.new_prefix, 1
|
|
129
|
+
)
|
|
130
|
+
|
|
118
131
|
return Inputs(proto_inputs=lm)
|
|
119
132
|
|
|
120
133
|
|
|
@@ -125,7 +138,7 @@ async def load_outputs(path: str, max_bytes: int = -1) -> Outputs:
|
|
|
125
138
|
If -1, reads the entire file.
|
|
126
139
|
:return: Outputs object
|
|
127
140
|
"""
|
|
128
|
-
lm =
|
|
141
|
+
lm = common_pb2.Outputs()
|
|
129
142
|
|
|
130
143
|
if max_bytes == -1:
|
|
131
144
|
proto_str = b"".join([c async for c in storage.get_stream(path=path)])
|
|
@@ -157,7 +170,7 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
|
|
|
157
170
|
err.ParseFromString(proto_str)
|
|
158
171
|
|
|
159
172
|
if err.error is not None:
|
|
160
|
-
user_code,
|
|
173
|
+
user_code, _server_code = _clean_error_code(err.error.code)
|
|
161
174
|
return execution_pb2.ExecutionError(
|
|
162
175
|
code=user_code,
|
|
163
176
|
message=err.error.message,
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from typing import List, Optional, Tuple
|
|
1
|
+
from typing import Dict, List, Optional, Tuple
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from flyteidl2.core import tasks_pb2
|
|
4
4
|
|
|
5
|
-
from flyte._resources import CPUBaseType, Resources
|
|
5
|
+
from flyte._resources import CPUBaseType, DeviceClass, Resources
|
|
6
6
|
|
|
7
7
|
ACCELERATOR_DEVICE_MAP = {
|
|
8
8
|
"A100": "nvidia-tesla-a100",
|
|
@@ -24,6 +24,14 @@ ACCELERATOR_DEVICE_MAP = {
|
|
|
24
24
|
"V6E": "tpu-v6e-slice",
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
+
_DeviceClassToProto: Dict[DeviceClass, "tasks_pb2.GPUAccelerator.DeviceClass"] = {
|
|
28
|
+
"GPU": tasks_pb2.GPUAccelerator.NVIDIA_GPU,
|
|
29
|
+
"TPU": tasks_pb2.GPUAccelerator.GOOGLE_TPU,
|
|
30
|
+
"NEURON": tasks_pb2.GPUAccelerator.AMAZON_NEURON,
|
|
31
|
+
"AMD_GPU": tasks_pb2.GPUAccelerator.AMD_GPU,
|
|
32
|
+
"HABANA_GAUDI": tasks_pb2.GPUAccelerator.HABANA_GAUDI,
|
|
33
|
+
}
|
|
34
|
+
|
|
27
35
|
|
|
28
36
|
def _get_cpu_resource_entry(cpu: CPUBaseType) -> tasks_pb2.Resources.ResourceEntry:
|
|
29
37
|
return tasks_pb2.Resources.ResourceEntry(
|
|
@@ -54,11 +62,17 @@ def _get_gpu_extended_resource_entry(resources: Resources) -> Optional[tasks_pb2
|
|
|
54
62
|
device = resources.get_device()
|
|
55
63
|
if device is None:
|
|
56
64
|
return None
|
|
57
|
-
|
|
58
|
-
|
|
65
|
+
|
|
66
|
+
device_class = _DeviceClassToProto.get(device.device_class, tasks_pb2.GPUAccelerator.NVIDIA_GPU)
|
|
67
|
+
if device.device is None:
|
|
68
|
+
raise RuntimeError("Device type must be specified for GPU string.")
|
|
69
|
+
else:
|
|
70
|
+
device_type = device.device
|
|
71
|
+
device_type = ACCELERATOR_DEVICE_MAP.get(device_type, device_type)
|
|
59
72
|
return tasks_pb2.GPUAccelerator(
|
|
60
|
-
device=
|
|
73
|
+
device=device_type,
|
|
61
74
|
partition_size=device.partition if device.partition else None,
|
|
75
|
+
device_class=device_class,
|
|
62
76
|
)
|
|
63
77
|
|
|
64
78
|
|
flyte/_internal/runtime/reuse.py
CHANGED
flyte/_internal/runtime/rusty.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import sys
|
|
3
2
|
import time
|
|
4
3
|
from typing import Any, List, Tuple
|
|
5
4
|
|
|
@@ -11,7 +10,8 @@ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_t
|
|
|
11
10
|
from flyte._internal.runtime.taskrunner import extract_download_run_upload
|
|
12
11
|
from flyte._logging import logger
|
|
13
12
|
from flyte._task import TaskTemplate
|
|
14
|
-
from flyte.
|
|
13
|
+
from flyte._utils import adjust_sys_path
|
|
14
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
|
|
@@ -23,7 +23,7 @@ async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
|
|
|
23
23
|
:return: The CodeBundle object.
|
|
24
24
|
"""
|
|
25
25
|
logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
|
|
26
|
-
|
|
26
|
+
adjust_sys_path()
|
|
27
27
|
|
|
28
28
|
code_bundle = CodeBundle(
|
|
29
29
|
tgz=tgz,
|
|
@@ -42,7 +42,7 @@ async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[C
|
|
|
42
42
|
:return: The CodeBundle object.
|
|
43
43
|
"""
|
|
44
44
|
logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
|
|
45
|
-
|
|
45
|
+
adjust_sys_path()
|
|
46
46
|
|
|
47
47
|
code_bundle = CodeBundle(
|
|
48
48
|
pkl=pkl,
|
|
@@ -115,6 +115,7 @@ async def run_task(
|
|
|
115
115
|
prev_checkpoint: str | None = None,
|
|
116
116
|
code_bundle: CodeBundle | None = None,
|
|
117
117
|
input_path: str | None = None,
|
|
118
|
+
path_rewrite_cfg: str | None = None,
|
|
118
119
|
):
|
|
119
120
|
"""
|
|
120
121
|
Runs the task with the provided parameters.
|
|
@@ -134,6 +135,7 @@ async def run_task(
|
|
|
134
135
|
:param controller: The controller to use for the task.
|
|
135
136
|
:param code_bundle: Optional code bundle for the task.
|
|
136
137
|
:param input_path: Optional input path for the task.
|
|
138
|
+
:param path_rewrite_cfg: Optional path rewrite configuration.
|
|
137
139
|
:return: The loaded task template.
|
|
138
140
|
"""
|
|
139
141
|
start_time = time.time()
|
|
@@ -144,6 +146,19 @@ async def run_task(
|
|
|
144
146
|
f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
|
|
145
147
|
)
|
|
146
148
|
|
|
149
|
+
path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
|
|
150
|
+
if path_rewrite:
|
|
151
|
+
import flyte.storage as storage
|
|
152
|
+
|
|
153
|
+
if not await storage.exists(path_rewrite.new_prefix):
|
|
154
|
+
logger.error(
|
|
155
|
+
f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
|
|
156
|
+
f"not found, reverting to original path {path_rewrite.old_prefix}"
|
|
157
|
+
)
|
|
158
|
+
path_rewrite = None
|
|
159
|
+
else:
|
|
160
|
+
logger.info(f"[rusty] Using path rewrite: {path_rewrite}")
|
|
161
|
+
|
|
147
162
|
try:
|
|
148
163
|
await contextual_run(
|
|
149
164
|
extract_download_run_upload,
|
|
@@ -151,7 +166,7 @@ async def run_task(
|
|
|
151
166
|
action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
|
|
152
167
|
version=version,
|
|
153
168
|
controller=controller,
|
|
154
|
-
raw_data_path=RawDataPath(path=raw_data_path),
|
|
169
|
+
raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
|
|
155
170
|
output_path=output_path,
|
|
156
171
|
run_base_dir=run_base_dir,
|
|
157
172
|
checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
|
|
@@ -8,14 +8,14 @@ import typing
|
|
|
8
8
|
from datetime import timedelta
|
|
9
9
|
from typing import Optional, cast
|
|
10
10
|
|
|
11
|
-
from
|
|
11
|
+
from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
|
|
12
|
+
from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
|
|
12
13
|
from google.protobuf import duration_pb2, wrappers_pb2
|
|
13
14
|
|
|
14
15
|
import flyte.errors
|
|
15
16
|
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
16
17
|
from flyte._logging import logger
|
|
17
18
|
from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
|
|
18
|
-
from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
|
|
19
19
|
from flyte._secret import SecretRequest, secrets_from_request
|
|
20
20
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
21
21
|
from flyte.models import CodeBundle, SerializationContext
|
|
@@ -54,7 +54,7 @@ def translate_task_to_wire(
|
|
|
54
54
|
return task_definition_pb2.TaskSpec(
|
|
55
55
|
task_template=tt,
|
|
56
56
|
default_inputs=default_inputs,
|
|
57
|
-
short_name=task.
|
|
57
|
+
short_name=task.short_name[:_MAX_TASK_SHORT_NAME_LENGTH],
|
|
58
58
|
environment=env,
|
|
59
59
|
)
|
|
60
60
|
|
|
@@ -119,7 +119,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
119
119
|
version=serialize_context.version,
|
|
120
120
|
)
|
|
121
121
|
|
|
122
|
-
# TODO Add support for
|
|
122
|
+
# TODO Add support for extra_config, custom
|
|
123
123
|
extra_config: typing.Dict[str, str] = {}
|
|
124
124
|
|
|
125
125
|
if task.pod_template and not isinstance(task.pod_template, str):
|
|
@@ -132,7 +132,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
132
132
|
|
|
133
133
|
custom = task.custom_config(serialize_context)
|
|
134
134
|
|
|
135
|
-
sql =
|
|
135
|
+
sql = task.sql(serialize_context)
|
|
136
136
|
|
|
137
137
|
# -------------- CACHE HANDLING ----------------------
|
|
138
138
|
task_cache = cache_from_request(task.cache)
|
|
@@ -145,7 +145,6 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
145
145
|
logger.debug(f"Detected pkl bundle for task {task.name}, using computed version as cache version")
|
|
146
146
|
cache_version = serialize_context.code_bundle.computed_version
|
|
147
147
|
else:
|
|
148
|
-
version_parameters = None
|
|
149
148
|
if isinstance(task, AsyncFunctionTaskTemplate):
|
|
150
149
|
version_parameters = VersionParameters(func=task.func, image=task.image)
|
|
151
150
|
else:
|
|
@@ -171,8 +170,9 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
171
170
|
retries=get_proto_retry_strategy(task.retries),
|
|
172
171
|
timeout=get_proto_timeout(task.timeout),
|
|
173
172
|
pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
|
|
174
|
-
interruptible=task.
|
|
173
|
+
interruptible=task.interruptible,
|
|
175
174
|
generates_deck=wrappers_pb2.BoolValue(value=task.report),
|
|
175
|
+
debuggable=task.debuggable,
|
|
176
176
|
),
|
|
177
177
|
interface=transform_native_to_typed_interface(task.native_interface),
|
|
178
178
|
custom=custom if len(custom) > 0 else None,
|
|
@@ -209,25 +209,40 @@ def _get_urun_container(
|
|
|
209
209
|
else None
|
|
210
210
|
)
|
|
211
211
|
resources = get_proto_resources(task_template.resources)
|
|
212
|
-
|
|
212
|
+
|
|
213
213
|
if isinstance(task_template.image, str):
|
|
214
214
|
raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")
|
|
215
|
-
|
|
215
|
+
|
|
216
|
+
env_name = task_template.parent_env_name
|
|
217
|
+
if env_name is None:
|
|
218
|
+
raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")
|
|
219
|
+
|
|
216
220
|
if not serialize_context.image_cache:
|
|
217
221
|
# This computes the image uri, computing hashes as necessary so can fail if done remotely.
|
|
218
222
|
img_uri = task_template.image.uri
|
|
219
|
-
elif serialize_context.image_cache and
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
223
|
+
elif serialize_context.image_cache and env_name not in serialize_context.image_cache.image_lookup:
|
|
224
|
+
raise flyte.errors.RuntimeUserError(
|
|
225
|
+
"MissingEnvironment",
|
|
226
|
+
f"Environment '{env_name}' not found in image cache.\n\n"
|
|
227
|
+
"💡 To fix this:\n"
|
|
228
|
+
" 1. If your parent environment calls a task in another environment,"
|
|
229
|
+
" declare that dependency using 'depends_on=[...]'.\n"
|
|
230
|
+
" Example:\n"
|
|
231
|
+
" env1 = flyte.TaskEnvironment(\n"
|
|
232
|
+
" name='outer',\n"
|
|
233
|
+
" image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
|
|
234
|
+
" depends_on=[env2, env3],\n"
|
|
235
|
+
" )\n"
|
|
236
|
+
" 2. If you're using os.getenv() to set the environment name,"
|
|
237
|
+
" make sure the runtime environment has the same environment variable defined.\n"
|
|
238
|
+
" Example:\n"
|
|
239
|
+
" env = flyte.TaskEnvironment(\n"
|
|
240
|
+
' name=os.getenv("my-name"),\n'
|
|
241
|
+
' env_vars={"my-name": os.getenv("my-name")},\n'
|
|
242
|
+
" )\n",
|
|
228
243
|
)
|
|
229
244
|
else:
|
|
230
|
-
img_uri = serialize_context.image_cache.image_lookup[
|
|
245
|
+
img_uri = serialize_context.image_cache.image_lookup[env_name]
|
|
231
246
|
|
|
232
247
|
return tasks_pb2.Container(
|
|
233
248
|
image=img_uri,
|