flyte 0.2.0b9__py3-none-any.whl → 0.2.0b11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +4 -2
- flyte/_bin/runtime.py +6 -3
- flyte/_deploy.py +3 -0
- flyte/_initialize.py +30 -6
- flyte/_internal/controllers/_local_controller.py +4 -3
- flyte/_internal/controllers/_trace.py +1 -0
- flyte/_internal/controllers/remote/_action.py +1 -1
- flyte/_internal/controllers/remote/_informer.py +1 -1
- flyte/_internal/runtime/convert.py +7 -4
- flyte/_internal/runtime/task_serde.py +80 -10
- flyte/_internal/runtime/taskrunner.py +1 -1
- flyte/_logging.py +1 -1
- flyte/_pod.py +19 -0
- flyte/_run.py +84 -39
- flyte/_task.py +2 -13
- flyte/_utils/org_discovery.py +31 -0
- flyte/_version.py +2 -2
- flyte/cli/_common.py +6 -6
- flyte/cli/_create.py +16 -8
- flyte/cli/_params.py +2 -2
- flyte/cli/_run.py +1 -1
- flyte/cli/main.py +4 -8
- flyte/errors.py +11 -0
- flyte/extras/_container.py +29 -39
- flyte/io/__init__.py +17 -1
- flyte/io/_file.py +2 -0
- flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
- flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
- flyte/models.py +1 -0
- flyte/remote/_data.py +2 -1
- flyte/types/__init__.py +23 -0
- flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
- flyte/types/_type_engine.py +7 -5
- {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/METADATA +5 -6
- {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/RECORD +39 -39
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
- {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/top_level.txt +0 -0
flyte/__init__.py
CHANGED
|
@@ -25,6 +25,7 @@ __all__ = [
|
|
|
25
25
|
"Device",
|
|
26
26
|
"Environment",
|
|
27
27
|
"Image",
|
|
28
|
+
"PodTemplate",
|
|
28
29
|
"Resources",
|
|
29
30
|
"RetryStrategy",
|
|
30
31
|
"ReusePolicy",
|
|
@@ -38,7 +39,7 @@ __all__ = [
|
|
|
38
39
|
"deploy",
|
|
39
40
|
"group",
|
|
40
41
|
"init",
|
|
41
|
-
"init_auto_from_config",
|
|
42
|
+
"init_from_config",
|
|
42
43
|
"map",
|
|
43
44
|
"run",
|
|
44
45
|
"trace",
|
|
@@ -51,8 +52,9 @@ from ._deploy import deploy
|
|
|
51
52
|
from ._environment import Environment
|
|
52
53
|
from ._group import group
|
|
53
54
|
from ._image import Image
|
|
54
|
-
from ._initialize import init, init_auto_from_config
|
|
55
|
+
from ._initialize import init, init_from_config
|
|
55
56
|
from ._map import map
|
|
57
|
+
from ._pod import PodTemplate
|
|
56
58
|
from ._resources import GPU, TPU, Device, Resources
|
|
57
59
|
from ._retry import RetryStrategy
|
|
58
60
|
from ._reusable_environment import ReusePolicy
|
flyte/_bin/runtime.py
CHANGED
|
@@ -76,13 +76,17 @@ def main(
|
|
|
76
76
|
):
|
|
77
77
|
sys.path.insert(0, ".")
|
|
78
78
|
|
|
79
|
+
import flyte
|
|
79
80
|
import flyte._utils as utils
|
|
80
81
|
from flyte._initialize import initialize_in_cluster
|
|
81
82
|
from flyte._internal.controllers import create_controller
|
|
82
83
|
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
83
84
|
from flyte._internal.runtime.entrypoints import load_and_run_task
|
|
85
|
+
from flyte._logging import logger
|
|
84
86
|
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
85
87
|
|
|
88
|
+
logger.info(f"Initializing flyte runtime - version {flyte.__version__}")
|
|
89
|
+
|
|
86
90
|
assert org, "Org is required for now"
|
|
87
91
|
assert project, "Project is required"
|
|
88
92
|
assert domain, "Domain is required"
|
|
@@ -98,15 +102,14 @@ def main(
|
|
|
98
102
|
# This detection of api key is a hack for now.
|
|
99
103
|
controller_kwargs: dict[str, Any] = {"insecure": False}
|
|
100
104
|
if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
logger.warning(f"Using api key {api_key}")
|
|
105
|
+
logger.info("Using api key from environment")
|
|
104
106
|
controller_kwargs["api_key"] = api_key
|
|
105
107
|
else:
|
|
106
108
|
ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
|
|
107
109
|
controller_kwargs["endpoint"] = ep
|
|
108
110
|
if "localhost" in ep or "docker" in ep:
|
|
109
111
|
controller_kwargs["insecure"] = True
|
|
112
|
+
logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
|
|
110
113
|
|
|
111
114
|
bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
|
|
112
115
|
initialize_in_cluster()
|
flyte/_deploy.py
CHANGED
|
@@ -128,6 +128,9 @@ async def apply(deployment: DeploymentPlan, copy_style: CopyFiles, dryrun: bool
|
|
|
128
128
|
else:
|
|
129
129
|
code_bundle = await build_code_bundle(from_dir=cfg.root_dir, dryrun=dryrun, copy_style=copy_style)
|
|
130
130
|
deployment.version = code_bundle.computed_version
|
|
131
|
+
# TODO we should update the version to include the image cache digest and code bundle digest. This is
|
|
132
|
+
# to ensure that changes in image dependencies, cause an update to the deployment version.
|
|
133
|
+
# TODO Also hash the environment and tasks to ensure that changes in the environment or tasks
|
|
131
134
|
|
|
132
135
|
sc = SerializationContext(
|
|
133
136
|
project=cfg.project,
|
flyte/_initialize.py
CHANGED
|
@@ -138,14 +138,13 @@ async def init(
|
|
|
138
138
|
:param project: Optional project name (not used in this implementation)
|
|
139
139
|
:param domain: Optional domain name (not used in this implementation)
|
|
140
140
|
:param root_dir: Optional root directory from which to determine how to load files, and find paths to files.
|
|
141
|
+
This is useful for determining the root directory for the current project, and for locating files like config etc.
|
|
142
|
+
also use to determine all the code that needs to be copied to the remote location.
|
|
141
143
|
defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
|
|
142
144
|
:param log_level: Optional logging level for the logger, default is set using the default initialization policies
|
|
143
145
|
:param api_key: Optional API key for authentication
|
|
144
146
|
:param endpoint: Optional API endpoint URL
|
|
145
147
|
:param headless: Optional Whether to run in headless mode
|
|
146
|
-
:param mode: Optional execution model (local, remote). Default is local. When local is used,
|
|
147
|
-
the execution will be done locally. When remote is used, the execution will be sent to a remote server,
|
|
148
|
-
In the remote case, the endpoint or api_key must be set.
|
|
149
148
|
:param insecure_skip_verify: Whether to skip SSL certificate verification
|
|
150
149
|
:param auth_client_config: Optional client configuration for authentication
|
|
151
150
|
:param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow)
|
|
@@ -169,6 +168,8 @@ async def init(
|
|
|
169
168
|
"""
|
|
170
169
|
from flyte._utils import get_cwd_editable_install
|
|
171
170
|
|
|
171
|
+
from ._utils.org_discovery import org_from_endpoint
|
|
172
|
+
|
|
172
173
|
interactive_mode = ipython_check()
|
|
173
174
|
|
|
174
175
|
initialize_logger(enable_rich=interactive_mode)
|
|
@@ -177,6 +178,9 @@ async def init(
|
|
|
177
178
|
|
|
178
179
|
global _init_config # noqa: PLW0603
|
|
179
180
|
|
|
181
|
+
if endpoint and "://" not in endpoint:
|
|
182
|
+
endpoint = f"dns:///{endpoint}"
|
|
183
|
+
|
|
180
184
|
with _init_lock:
|
|
181
185
|
client = None
|
|
182
186
|
if endpoint or api_key:
|
|
@@ -204,17 +208,25 @@ async def init(
|
|
|
204
208
|
domain=domain,
|
|
205
209
|
client=client,
|
|
206
210
|
storage=storage,
|
|
207
|
-
org=org,
|
|
211
|
+
org=org or org_from_endpoint(endpoint),
|
|
208
212
|
)
|
|
209
213
|
|
|
210
214
|
|
|
211
215
|
@syncify
|
|
212
|
-
async def init_auto_from_config(path_or_config: str | Config | None = None) -> None:
|
|
216
|
+
async def init_from_config(
|
|
217
|
+
path_or_config: str | Config | None = None, root_dir: Path | None = None, log_level: int | None = None
|
|
218
|
+
) -> None:
|
|
213
219
|
"""
|
|
214
220
|
Initialize the Flyte system using a configuration file or Config object. This method should be called before any
|
|
215
221
|
other Flyte remote API methods are called. Thread-safe implementation.
|
|
216
222
|
|
|
217
223
|
:param path_or_config: Path to the configuration file or Config object
|
|
224
|
+
:param root_dir: Optional root directory from which to determine how to load files, and find paths to
|
|
225
|
+
files like config etc. For example if one uses the copy-style=="all", it is essential to determine the
|
|
226
|
+
root directory for the current project. If not provided, it defaults to the editable install directory or
|
|
227
|
+
if not available, the current working directory.
|
|
228
|
+
:param log_level: Optional logging level for the framework logger,
|
|
229
|
+
default is set using the default initialization policies
|
|
218
230
|
:return: None
|
|
219
231
|
"""
|
|
220
232
|
import flyte.config as config
|
|
@@ -222,7 +234,17 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
|
|
|
222
234
|
cfg: config.Config
|
|
223
235
|
if path_or_config is None or isinstance(path_or_config, str):
|
|
224
236
|
# If a string is passed, treat it as a path to the config file
|
|
225
|
-
|
|
237
|
+
if path_or_config:
|
|
238
|
+
if not Path(path_or_config).exists():
|
|
239
|
+
raise InitializationError(
|
|
240
|
+
"ConfigFileNotFoundError",
|
|
241
|
+
"user",
|
|
242
|
+
f"Configuration file '{path_or_config}' does not exist., current working directory is {Path.cwd()}",
|
|
243
|
+
)
|
|
244
|
+
if root_dir and path_or_config:
|
|
245
|
+
cfg = config.auto(str(root_dir / path_or_config))
|
|
246
|
+
else:
|
|
247
|
+
cfg = config.auto(path_or_config)
|
|
226
248
|
else:
|
|
227
249
|
# If a Config object is passed, use it directly
|
|
228
250
|
cfg = path_or_config
|
|
@@ -241,6 +263,8 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
|
|
|
241
263
|
proxy_command=cfg.platform.proxy_command,
|
|
242
264
|
client_id=cfg.platform.client_id,
|
|
243
265
|
client_credentials_secret=cfg.platform.client_credentials_secret,
|
|
266
|
+
root_dir=root_dir,
|
|
267
|
+
log_level=log_level,
|
|
244
268
|
)
|
|
245
269
|
|
|
246
270
|
|
flyte/_internal/controllers/_local_controller.py
CHANGED
|
@@ -14,7 +14,7 @@ from flyte._logging import log, logger
|
|
|
14
14
|
from flyte._protos.workflow import task_definition_pb2
|
|
15
15
|
from flyte._task import TaskTemplate
|
|
16
16
|
from flyte._utils.helpers import _selector_policy
|
|
17
|
-
from flyte.models import ActionID, NativeInterface
|
|
17
|
+
from flyte.models import ActionID, NativeInterface
|
|
18
18
|
|
|
19
19
|
R = TypeVar("R")
|
|
20
20
|
|
|
@@ -86,7 +86,7 @@ class LocalController:
|
|
|
86
86
|
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
87
87
|
tctx, _task.name, serialized_inputs, 0
|
|
88
88
|
)
|
|
89
|
-
sub_action_raw_data_path =
|
|
89
|
+
sub_action_raw_data_path = tctx.raw_data_path
|
|
90
90
|
|
|
91
91
|
out, err = await direct_dispatch(
|
|
92
92
|
_task,
|
|
@@ -162,6 +162,7 @@ class LocalController:
|
|
|
162
162
|
action=action_id,
|
|
163
163
|
interface=_interface,
|
|
164
164
|
inputs_path=action_output_path,
|
|
165
|
+
name=_func.__name__,
|
|
165
166
|
),
|
|
166
167
|
True,
|
|
167
168
|
)
|
|
@@ -179,7 +180,7 @@ class LocalController:
|
|
|
179
180
|
|
|
180
181
|
if info.interface.outputs and info.output:
|
|
181
182
|
# If the result is not an AsyncGenerator, convert it directly
|
|
182
|
-
converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
|
|
183
|
+
converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
|
|
183
184
|
assert converted_outputs
|
|
184
185
|
elif info.error:
|
|
185
186
|
# If there is an error, convert it to a native error
|
flyte/_internal/controllers/remote/_action.py
CHANGED
|
@@ -130,7 +130,7 @@ class Action:
|
|
|
130
130
|
"""
|
|
131
131
|
from flyte._logging import logger
|
|
132
132
|
|
|
133
|
-
logger.
|
|
133
|
+
logger.debug(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
|
|
134
134
|
return cls(
|
|
135
135
|
action_id=obj.action_id,
|
|
136
136
|
parent_action_name=parent_action_name,
|
flyte/_internal/controllers/remote/_informer.py
CHANGED
|
@@ -235,7 +235,7 @@ class Informer:
|
|
|
235
235
|
await self._shared_queue.put(node)
|
|
236
236
|
# hack to work in the absence of sentinel
|
|
237
237
|
except asyncio.CancelledError:
|
|
238
|
-
logger.
|
|
238
|
+
logger.info(f"Watch cancelled: {self.name}")
|
|
239
239
|
return
|
|
240
240
|
except asyncio.TimeoutError as e:
|
|
241
241
|
logger.error(f"Watch timeout: {self.name}", exc_info=e)
|
flyte/_internal/runtime/convert.py
CHANGED
|
@@ -11,7 +11,7 @@ import flyte.errors
|
|
|
11
11
|
import flyte.storage as storage
|
|
12
12
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
13
13
|
from flyte.models import ActionID, NativeInterface, TaskContext
|
|
14
|
-
from flyte.types import TypeEngine
|
|
14
|
+
from flyte.types import TypeEngine, TypeTransformerFailedError
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@dataclass(frozen=True)
|
|
@@ -80,7 +80,7 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
|
|
|
80
80
|
)
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) -> Outputs:
|
|
83
|
+
async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
|
|
84
84
|
# Always make it a tuple even if it's just one item to simplify logic below
|
|
85
85
|
if not isinstance(o, tuple):
|
|
86
86
|
o = (o,)
|
|
@@ -90,8 +90,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) ->
|
|
|
90
90
|
)
|
|
91
91
|
named = []
|
|
92
92
|
for (output_name, python_type), v in zip(interface.outputs.items(), o):
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
try:
|
|
94
|
+
lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
|
|
95
|
+
named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
|
|
96
|
+
except TypeTransformerFailedError as e:
|
|
97
|
+
raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
|
|
95
98
|
|
|
96
99
|
return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
|
|
97
100
|
|
flyte/_internal/runtime/task_serde.py
CHANGED
|
@@ -3,9 +3,11 @@ This module provides functionality to serialize and deserialize tasks to and fro
|
|
|
3
3
|
It includes a Resolver interface for loading tasks, and functions to load classes and tasks.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import copy
|
|
6
7
|
import importlib
|
|
8
|
+
import typing
|
|
7
9
|
from datetime import timedelta
|
|
8
|
-
from typing import Optional, Type
|
|
10
|
+
from typing import Optional, Type, cast
|
|
9
11
|
|
|
10
12
|
from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
|
|
11
13
|
from google.protobuf import duration_pb2, wrappers_pb2
|
|
@@ -13,6 +15,7 @@ from google.protobuf import duration_pb2, wrappers_pb2
|
|
|
13
15
|
import flyte.errors
|
|
14
16
|
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
15
17
|
from flyte._logging import logger
|
|
18
|
+
from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
|
|
16
19
|
from flyte._protos.workflow import task_definition_pb2
|
|
17
20
|
from flyte._secret import SecretRequest, secrets_from_request
|
|
18
21
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
@@ -121,17 +124,18 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
121
124
|
# if task.parent_env is None:
|
|
122
125
|
# raise ValueError(f"Task {task.name} must have a parent environment")
|
|
123
126
|
|
|
124
|
-
#
|
|
125
|
-
|
|
126
|
-
#
|
|
127
|
-
container = _get_urun_container(serialize_context, task)
|
|
127
|
+
# TODO Add support for SQL, extra_config, custom
|
|
128
|
+
extra_config: typing.Dict[str, str] = {}
|
|
129
|
+
custom = {} # type: ignore
|
|
128
130
|
|
|
129
|
-
# TODO Add support for SQL, Pod, extra_config, custom
|
|
130
|
-
pod = None
|
|
131
131
|
sql = None
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
132
|
+
if task.pod_template and not isinstance(task.pod_template, str):
|
|
133
|
+
container = None
|
|
134
|
+
pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
|
|
135
|
+
extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
|
|
136
|
+
else:
|
|
137
|
+
container = _get_urun_container(serialize_context, task)
|
|
138
|
+
pod = None
|
|
135
139
|
|
|
136
140
|
# -------------- CACHE HANDLING ----------------------
|
|
137
141
|
task_cache = cache_from_request(task.cache)
|
|
@@ -210,6 +214,72 @@ def _get_urun_container(
|
|
|
210
214
|
)
|
|
211
215
|
|
|
212
216
|
|
|
217
|
+
def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
|
|
218
|
+
return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
|
|
222
|
+
"""
|
|
223
|
+
Get the K8sPod representation of the task template.
|
|
224
|
+
:param task: The task to convert.
|
|
225
|
+
:return: The K8sPod representation of the task template.
|
|
226
|
+
"""
|
|
227
|
+
from kubernetes.client import ApiClient, V1PodSpec
|
|
228
|
+
from kubernetes.client.models import V1EnvVar, V1ResourceRequirements
|
|
229
|
+
|
|
230
|
+
pod_template = copy.deepcopy(pod_template)
|
|
231
|
+
containers = cast(V1PodSpec, pod_template.pod_spec).containers
|
|
232
|
+
primary_exists = False
|
|
233
|
+
|
|
234
|
+
for container in containers:
|
|
235
|
+
if container.name == pod_template.primary_container_name:
|
|
236
|
+
primary_exists = True
|
|
237
|
+
break
|
|
238
|
+
|
|
239
|
+
if not primary_exists:
|
|
240
|
+
raise ValueError(
|
|
241
|
+
"No primary container defined in the pod spec."
|
|
242
|
+
f" You must define a primary container with the name '{pod_template.primary_container_name}'."
|
|
243
|
+
)
|
|
244
|
+
final_containers = []
|
|
245
|
+
|
|
246
|
+
for container in containers:
|
|
247
|
+
# We overwrite the primary container attributes with the values given to ContainerTask.
|
|
248
|
+
# The attributes include: image, command, args, resource, and env (env is unioned)
|
|
249
|
+
|
|
250
|
+
if container.name == pod_template.primary_container_name:
|
|
251
|
+
if container.image is None:
|
|
252
|
+
# Copy the image from primary_container only if the image is not specified in the pod spec.
|
|
253
|
+
container.image = primary_container.image
|
|
254
|
+
|
|
255
|
+
container.command = list(primary_container.command)
|
|
256
|
+
container.args = list(primary_container.args)
|
|
257
|
+
|
|
258
|
+
limits, requests = {}, {}
|
|
259
|
+
for resource in primary_container.resources.limits:
|
|
260
|
+
limits[_sanitize_resource_name(resource)] = resource.value
|
|
261
|
+
for resource in primary_container.resources.requests:
|
|
262
|
+
requests[_sanitize_resource_name(resource)] = resource.value
|
|
263
|
+
|
|
264
|
+
resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
|
|
265
|
+
if len(limits) > 0 or len(requests) > 0:
|
|
266
|
+
# Important! Only copy over resource requirements if they are non-empty.
|
|
267
|
+
container.resources = resource_requirements
|
|
268
|
+
|
|
269
|
+
if primary_container.env is not None:
|
|
270
|
+
container.env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env] + (
|
|
271
|
+
container.env or []
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
final_containers.append(container)
|
|
275
|
+
|
|
276
|
+
cast(V1PodSpec, pod_template.pod_spec).containers = final_containers
|
|
277
|
+
pod_spec = ApiClient().sanitize_for_serialization(pod_template.pod_spec)
|
|
278
|
+
|
|
279
|
+
metadata = tasks_pb2.K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
|
|
280
|
+
return tasks_pb2.K8sPod(pod_spec=pod_spec, metadata=metadata)
|
|
281
|
+
|
|
282
|
+
|
|
213
283
|
def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
|
|
214
284
|
"""
|
|
215
285
|
Extract the code bundle from the task spec.
|
flyte/_internal/runtime/taskrunner.py
CHANGED
|
@@ -145,7 +145,7 @@ async def convert_and_run(
|
|
|
145
145
|
return None, convert_from_native_to_error(err)
|
|
146
146
|
if task.report:
|
|
147
147
|
await flyte.report.flush.aio()
|
|
148
|
-
return await convert_from_native_to_outputs(out, task.native_interface), None
|
|
148
|
+
return await convert_from_native_to_outputs(out, task.native_interface, task.name), None
|
|
149
149
|
|
|
150
150
|
|
|
151
151
|
async def extract_download_run_upload(
|
flyte/_logging.py
CHANGED
flyte/_pod.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from kubernetes.client import V1PodSpec
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
|
|
9
|
+
_PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(init=True, repr=True, eq=True, frozen=False)
|
|
13
|
+
class PodTemplate(object):
|
|
14
|
+
"""Custom PodTemplate specification for a Task."""
|
|
15
|
+
|
|
16
|
+
pod_spec: Optional["V1PodSpec"] = field(default_factory=lambda: V1PodSpec())
|
|
17
|
+
primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
|
|
18
|
+
labels: Optional[Dict[str, str]] = None
|
|
19
|
+
annotations: Optional[Dict[str, str]] = None
|
flyte/_run.py
CHANGED
|
@@ -5,6 +5,7 @@ import pathlib
|
|
|
5
5
|
import uuid
|
|
6
6
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
|
|
7
7
|
|
|
8
|
+
import flyte.errors
|
|
8
9
|
from flyte.errors import InitializationError
|
|
9
10
|
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
|
|
10
11
|
from flyte.syncify import syncify
|
|
@@ -25,6 +26,7 @@ from ._tools import ipython_check
|
|
|
25
26
|
|
|
26
27
|
if TYPE_CHECKING:
|
|
27
28
|
from flyte.remote import Run
|
|
29
|
+
from flyte.remote._task import LazyEntity
|
|
28
30
|
|
|
29
31
|
from ._code_bundle import CopyFiles
|
|
30
32
|
|
|
@@ -81,8 +83,11 @@ class _Runner:
|
|
|
81
83
|
self._run_base_dir = run_base_dir or "/tmp/base"
|
|
82
84
|
|
|
83
85
|
@requires_initialization
|
|
84
|
-
async def _run_remote(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
86
|
+
async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
87
|
+
import grpc
|
|
88
|
+
|
|
85
89
|
from flyte.remote import Run
|
|
90
|
+
from flyte.remote._task import LazyEntity
|
|
86
91
|
|
|
87
92
|
from ._code_bundle import build_code_bundle, build_pkl_bundle
|
|
88
93
|
from ._deploy import build_images, plan_deploy
|
|
@@ -93,37 +98,47 @@ class _Runner:
|
|
|
93
98
|
|
|
94
99
|
cfg = get_common_config()
|
|
95
100
|
|
|
96
|
-
if obj
|
|
97
|
-
|
|
101
|
+
if isinstance(obj, LazyEntity):
|
|
102
|
+
task = await obj.fetch.aio()
|
|
103
|
+
task_spec = task.pb2.spec
|
|
104
|
+
inputs = await convert_from_native_to_inputs(task.interface, *args, **kwargs)
|
|
105
|
+
version = task.pb2.task_id.version
|
|
106
|
+
code_bundle = None
|
|
107
|
+
else:
|
|
108
|
+
if obj.parent_env is None:
|
|
109
|
+
raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
|
|
98
110
|
|
|
99
|
-
|
|
100
|
-
|
|
111
|
+
deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
|
|
112
|
+
image_cache = await build_images(deploy_plan)
|
|
101
113
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
)
|
|
106
|
-
else:
|
|
107
|
-
if self._copy_files != "none":
|
|
108
|
-
code_bundle = await build_code_bundle(
|
|
109
|
-
from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
|
|
114
|
+
if self._interactive_mode:
|
|
115
|
+
code_bundle = await build_pkl_bundle(
|
|
116
|
+
obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
|
|
110
117
|
)
|
|
111
118
|
else:
|
|
112
|
-
|
|
119
|
+
if self._copy_files != "none":
|
|
120
|
+
code_bundle = await build_code_bundle(
|
|
121
|
+
from_dir=cfg.root_dir,
|
|
122
|
+
dryrun=self._dry_run,
|
|
123
|
+
copy_bundle_to=self._copy_bundle_to,
|
|
124
|
+
copy_style=self._copy_files,
|
|
125
|
+
)
|
|
126
|
+
else:
|
|
127
|
+
code_bundle = None
|
|
113
128
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
129
|
+
version = self._version or (
|
|
130
|
+
code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
|
|
131
|
+
)
|
|
132
|
+
if not version:
|
|
133
|
+
raise ValueError("Version is required when running a task")
|
|
134
|
+
s_ctx = SerializationContext(
|
|
135
|
+
code_bundle=code_bundle,
|
|
136
|
+
version=version,
|
|
137
|
+
image_cache=image_cache,
|
|
138
|
+
root_dir=cfg.root_dir,
|
|
139
|
+
)
|
|
140
|
+
task_spec = translate_task_to_wire(obj, s_ctx)
|
|
141
|
+
inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
|
|
127
142
|
|
|
128
143
|
if not self._dry_run:
|
|
129
144
|
if get_client() is None:
|
|
@@ -160,15 +175,35 @@ class _Runner:
|
|
|
160
175
|
if task_spec.task_template.id.version == "":
|
|
161
176
|
task_spec.task_template.id.version = version
|
|
162
177
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
178
|
+
try:
|
|
179
|
+
resp = await get_client().run_service.CreateRun(
|
|
180
|
+
run_service_pb2.CreateRunRequest(
|
|
181
|
+
run_id=run_id,
|
|
182
|
+
project_id=project_id,
|
|
183
|
+
task_spec=task_spec,
|
|
184
|
+
inputs=inputs.proto_inputs,
|
|
185
|
+
),
|
|
186
|
+
)
|
|
187
|
+
return Run(pb2=resp.run)
|
|
188
|
+
except grpc.aio.AioRpcError as e:
|
|
189
|
+
if e.code() == grpc.StatusCode.UNAVAILABLE:
|
|
190
|
+
raise flyte.errors.RuntimeSystemError(
|
|
191
|
+
"SystemUnavailableError",
|
|
192
|
+
"Flyte system is currently unavailable. check your configuration, or the service status.",
|
|
193
|
+
) from e
|
|
194
|
+
elif e.code() == grpc.StatusCode.INVALID_ARGUMENT:
|
|
195
|
+
raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.details())
|
|
196
|
+
elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
|
197
|
+
# TODO maybe this should be a pass and return existing run?
|
|
198
|
+
raise flyte.errors.RuntimeUserError(
|
|
199
|
+
"RunAlreadyExistsError",
|
|
200
|
+
f"A run with the name '{self._name}' already exists. Please choose a different name.",
|
|
201
|
+
)
|
|
202
|
+
else:
|
|
203
|
+
raise flyte.errors.RuntimeSystemError(
|
|
204
|
+
"RunCreationError",
|
|
205
|
+
f"Failed to create run: {e.details()}",
|
|
206
|
+
) from e
|
|
172
207
|
|
|
173
208
|
class DryRun(Run):
|
|
174
209
|
def __init__(self, _task_spec, _inputs, _code_bundle):
|
|
@@ -225,7 +260,10 @@ class _Runner:
|
|
|
225
260
|
else:
|
|
226
261
|
if self._copy_files != "none":
|
|
227
262
|
code_bundle = await build_code_bundle(
|
|
228
|
-
from_dir=cfg.root_dir,
|
|
263
|
+
from_dir=cfg.root_dir,
|
|
264
|
+
dryrun=self._dry_run,
|
|
265
|
+
copy_bundle_to=self._copy_bundle_to,
|
|
266
|
+
copy_style=self._copy_files,
|
|
229
267
|
)
|
|
230
268
|
else:
|
|
231
269
|
code_bundle = None
|
|
@@ -313,7 +351,7 @@ class _Runner:
|
|
|
313
351
|
report=Report(name=action.name),
|
|
314
352
|
mode="local",
|
|
315
353
|
)
|
|
316
|
-
|
|
354
|
+
with ctx.replace_task_context(tctx):
|
|
317
355
|
# make the local version always runs on a different thread, returns a wrapped future.
|
|
318
356
|
if obj._call_as_synchronous:
|
|
319
357
|
fut = controller.submit_sync(obj, *args, **kwargs)
|
|
@@ -323,7 +361,9 @@ class _Runner:
|
|
|
323
361
|
return await controller.submit(obj, *args, **kwargs)
|
|
324
362
|
|
|
325
363
|
@syncify
|
|
326
|
-
async def run(
|
|
364
|
+
async def run(
|
|
365
|
+
self, task: TaskTemplate[P, Union[R, Run]] | LazyEntity, *args: P.args, **kwargs: P.kwargs
|
|
366
|
+
) -> Union[R, Run]:
|
|
327
367
|
"""
|
|
328
368
|
Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.
|
|
329
369
|
|
|
@@ -345,8 +385,13 @@ class _Runner:
|
|
|
345
385
|
:param kwargs: Keyword arguments to pass to the Task
|
|
346
386
|
:return: Run instance or the result of the task
|
|
347
387
|
"""
|
|
388
|
+
from flyte.remote._task import LazyEntity
|
|
389
|
+
|
|
390
|
+
if isinstance(task, LazyEntity) and self._mode != "remote":
|
|
391
|
+
raise ValueError("Remote task can only be run in remote mode.")
|
|
348
392
|
if self._mode == "remote":
|
|
349
393
|
return await self._run_remote(task, *args, **kwargs)
|
|
394
|
+
task = cast(TaskTemplate, task)
|
|
350
395
|
if self._mode == "hybrid":
|
|
351
396
|
return await self._run_hybrid(task, *args, **kwargs)
|
|
352
397
|
|
flyte/_task.py
CHANGED
|
@@ -23,6 +23,7 @@ from typing import (
|
|
|
23
23
|
|
|
24
24
|
from flyteidl.core.tasks_pb2 import DataLoadingConfig
|
|
25
25
|
|
|
26
|
+
from flyte._pod import PodTemplate
|
|
26
27
|
from flyte.errors import RuntimeSystemError, RuntimeUserError
|
|
27
28
|
|
|
28
29
|
from ._cache import Cache, CacheRequest
|
|
@@ -37,8 +38,6 @@ from ._timeout import TimeoutType
|
|
|
37
38
|
from .models import NativeInterface, SerializationContext
|
|
38
39
|
|
|
39
40
|
if TYPE_CHECKING:
|
|
40
|
-
from kubernetes.client import V1PodTemplate
|
|
41
|
-
|
|
42
41
|
from ._task_environment import TaskEnvironment
|
|
43
42
|
|
|
44
43
|
P = ParamSpec("P") # capture the function's parameters
|
|
@@ -96,8 +95,7 @@ class TaskTemplate(Generic[P, R]):
|
|
|
96
95
|
env: Optional[Dict[str, str]] = None
|
|
97
96
|
secrets: Optional[SecretRequest] = None
|
|
98
97
|
timeout: Optional[TimeoutType] = None
|
|
99
|
-
|
|
100
|
-
pod_template: Optional[Union[str, V1PodTemplate]] = None
|
|
98
|
+
pod_template: Optional[Union[str, PodTemplate]] = None
|
|
101
99
|
report: bool = False
|
|
102
100
|
|
|
103
101
|
parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
|
|
@@ -108,15 +106,6 @@ class TaskTemplate(Generic[P, R]):
|
|
|
108
106
|
_call_as_synchronous: bool = False
|
|
109
107
|
|
|
110
108
|
def __post_init__(self):
|
|
111
|
-
# If pod_template is set to a pod, verify
|
|
112
|
-
if self.pod_template is not None and not isinstance(self.pod_template, str):
|
|
113
|
-
try:
|
|
114
|
-
from kubernetes.client import V1PodTemplate # noqa: F401
|
|
115
|
-
except ImportError as e:
|
|
116
|
-
raise ImportError(
|
|
117
|
-
"kubernetes is not installed, please install kubernetes package to use pod_template"
|
|
118
|
-
) from e
|
|
119
|
-
|
|
120
109
|
# Auto set the image based on the image request
|
|
121
110
|
if self.image == "auto":
|
|
122
111
|
self.image = Image.auto()
|