flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +3 -4
- flyte/_bin/runtime.py +21 -7
- flyte/_cache/cache.py +1 -2
- flyte/_cli/_common.py +26 -4
- flyte/_cli/_create.py +48 -0
- flyte/_cli/_deploy.py +4 -2
- flyte/_cli/_get.py +18 -7
- flyte/_cli/_run.py +1 -0
- flyte/_cli/main.py +11 -5
- flyte/_code_bundle/bundle.py +42 -11
- flyte/_context.py +1 -1
- flyte/_deploy.py +3 -1
- flyte/_group.py +1 -1
- flyte/_initialize.py +28 -247
- flyte/_internal/controllers/__init__.py +6 -6
- flyte/_internal/controllers/_local_controller.py +14 -5
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +1 -1
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +68 -24
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/runtime/convert.py +34 -8
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +31 -1
- flyte/_internal/runtime/taskrunner.py +1 -1
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_run.py +47 -28
- flyte/_task.py +2 -2
- flyte/_task_environment.py +1 -1
- flyte/_trace.py +5 -6
- flyte/_utils/__init__.py +2 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_version.py +2 -2
- flyte/config/__init__.py +26 -4
- flyte/config/_config.py +13 -4
- flyte/extras/_container.py +3 -3
- flyte/{_datastructures.py → models.py} +3 -2
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +28 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_logs.py +1 -1
- flyte/remote/_run.py +4 -8
- flyte/remote/_task.py +2 -2
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/types/_interface.py +1 -1
- flyte/types/_type_engine.py +1 -1
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
- flyte/_internal/controllers/pbhash.py +0 -39
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,9 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from collections import defaultdict
|
|
5
|
+
from collections.abc import Callable
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
|
|
7
|
+
from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
|
|
7
8
|
|
|
8
9
|
import flyte
|
|
9
10
|
import flyte.errors
|
|
@@ -11,7 +12,6 @@ import flyte.storage as storage
|
|
|
11
12
|
import flyte.types as types
|
|
12
13
|
from flyte._code_bundle import build_pkl_bundle
|
|
13
14
|
from flyte._context import internal_ctx
|
|
14
|
-
from flyte._datastructures import ActionID, NativeInterface, SerializationContext
|
|
15
15
|
from flyte._internal.controllers import TraceInfo
|
|
16
16
|
from flyte._internal.controllers.remote._action import Action
|
|
17
17
|
from flyte._internal.controllers.remote._core import Controller
|
|
@@ -21,16 +21,17 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
|
21
21
|
from flyte._logging import logger
|
|
22
22
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
23
23
|
from flyte._task import TaskTemplate
|
|
24
|
+
from flyte.models import ActionID, NativeInterface, SerializationContext
|
|
24
25
|
|
|
25
26
|
R = TypeVar("R")
|
|
26
27
|
|
|
27
28
|
|
|
28
|
-
async def upload_inputs_with_retry(
|
|
29
|
+
async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
|
|
29
30
|
"""
|
|
30
31
|
Upload inputs to the specified URI with error handling.
|
|
31
32
|
|
|
32
33
|
Args:
|
|
33
|
-
|
|
34
|
+
serialized_inputs: The serialized inputs to upload
|
|
34
35
|
inputs_uri: The destination URI
|
|
35
36
|
|
|
36
37
|
Raises:
|
|
@@ -38,9 +39,9 @@ async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> N
|
|
|
38
39
|
"""
|
|
39
40
|
try:
|
|
40
41
|
# TODO Add retry decorator to this
|
|
41
|
-
await
|
|
42
|
+
await storage.put_stream(serialized_inputs, to_path=inputs_uri)
|
|
42
43
|
except Exception as e:
|
|
43
|
-
logger.exception("Failed to upload inputs"
|
|
44
|
+
logger.exception("Failed to upload inputs")
|
|
44
45
|
raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
|
|
45
46
|
|
|
46
47
|
|
|
@@ -89,6 +90,10 @@ async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri:
|
|
|
89
90
|
return await convert.convert_outputs_to_native(iface, outputs)
|
|
90
91
|
|
|
91
92
|
|
|
93
|
+
def unique_action_name(action_id: ActionID) -> str:
|
|
94
|
+
return f"{action_id.name}_{action_id.run_name}"
|
|
95
|
+
|
|
96
|
+
|
|
92
97
|
class RemoteController(Controller):
|
|
93
98
|
"""
|
|
94
99
|
This a specialized controller that wraps the core controller and performs IO, serialization and deserialization
|
|
@@ -111,31 +116,42 @@ class RemoteController(Controller):
|
|
|
111
116
|
self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
|
|
112
117
|
lambda: asyncio.Semaphore(default_parent_concurrency)
|
|
113
118
|
)
|
|
119
|
+
self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
|
|
120
|
+
lambda: defaultdict(int)
|
|
121
|
+
)
|
|
114
122
|
|
|
115
|
-
|
|
123
|
+
def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
|
|
124
|
+
"""
|
|
125
|
+
Generate a task call sequence for the given task object and action ID.
|
|
126
|
+
This is used to track the number of times a task is called within an action.
|
|
127
|
+
"""
|
|
128
|
+
current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
|
|
129
|
+
current_task_id = id(task_obj)
|
|
130
|
+
v = current_action_sequencer[current_task_id]
|
|
131
|
+
new_seq = v + 1
|
|
132
|
+
current_action_sequencer[current_task_id] = new_seq
|
|
133
|
+
return new_seq
|
|
134
|
+
|
|
135
|
+
async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
116
136
|
ctx = internal_ctx()
|
|
117
137
|
tctx = ctx.data.task_context
|
|
118
138
|
if tctx is None:
|
|
119
139
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
120
140
|
current_action_id = tctx.action
|
|
121
141
|
|
|
122
|
-
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
|
|
123
|
-
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
|
|
124
|
-
|
|
125
142
|
# In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
|
|
126
143
|
# It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
|
|
127
144
|
code_bundle = tctx.code_bundle
|
|
128
145
|
|
|
129
146
|
if code_bundle and code_bundle.pkl:
|
|
130
|
-
logger.debug(f"Building new pkl bundle for task {
|
|
147
|
+
logger.debug(f"Building new pkl bundle for task {_task.name}")
|
|
131
148
|
code_bundle = await build_pkl_bundle(
|
|
132
149
|
_task,
|
|
133
150
|
upload_to_controlplane=False,
|
|
134
|
-
|
|
151
|
+
upload_from_dataplane_base_path=tctx.run_base_dir,
|
|
135
152
|
)
|
|
136
153
|
|
|
137
|
-
|
|
138
|
-
await upload_inputs_with_retry(inputs, inputs_uri)
|
|
154
|
+
inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
|
|
139
155
|
|
|
140
156
|
root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
|
|
141
157
|
# Don't set output path in sec context because node executor will set it
|
|
@@ -146,13 +162,23 @@ class RemoteController(Controller):
|
|
|
146
162
|
code_bundle=code_bundle,
|
|
147
163
|
version=tctx.version,
|
|
148
164
|
# supplied version.
|
|
149
|
-
input_path=inputs_uri,
|
|
165
|
+
# input_path=inputs_uri,
|
|
150
166
|
image_cache=tctx.compiled_image_cache,
|
|
151
167
|
root_dir=root_dir,
|
|
152
168
|
)
|
|
153
169
|
|
|
154
170
|
task_spec = translate_task_to_wire(_task, new_serialization_context)
|
|
155
171
|
|
|
172
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
173
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
174
|
+
tctx, task_spec, serialized_inputs, _task_call_seq
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
178
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
179
|
+
# Clear to free memory
|
|
180
|
+
serialized_inputs = None # type: ignore
|
|
181
|
+
|
|
156
182
|
action = Action.from_task(
|
|
157
183
|
sub_action_id=run_definition_pb2.ActionIdentifier(
|
|
158
184
|
name=sub_action_id.name,
|
|
@@ -205,8 +231,9 @@ class RemoteController(Controller):
|
|
|
205
231
|
if tctx is None:
|
|
206
232
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
207
233
|
current_action_id = tctx.action
|
|
208
|
-
|
|
209
|
-
|
|
234
|
+
task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
|
|
235
|
+
async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
|
|
236
|
+
return await self._submit(task_call_seq, _task, *args, **kwargs)
|
|
210
237
|
|
|
211
238
|
async def finalize_parent_action(self, action_id: ActionID):
|
|
212
239
|
"""
|
|
@@ -220,16 +247,17 @@ class RemoteController(Controller):
|
|
|
220
247
|
org=action_id.org,
|
|
221
248
|
)
|
|
222
249
|
await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
|
|
223
|
-
self._parent_action_semaphore.pop(action_id
|
|
250
|
+
self._parent_action_semaphore.pop(unique_action_name(action_id), None)
|
|
251
|
+
self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)
|
|
224
252
|
|
|
225
253
|
async def get_action_outputs(
|
|
226
|
-
self, _interface: NativeInterface,
|
|
254
|
+
self, _interface: NativeInterface, _func: Callable, *args, **kwargs
|
|
227
255
|
) -> Tuple[TraceInfo, bool]:
|
|
228
256
|
"""
|
|
229
257
|
This method returns the outputs of the action, if it is available.
|
|
230
258
|
If not available it raises a NotFoundError.
|
|
231
259
|
:param _interface: NativeInterface
|
|
232
|
-
:param
|
|
260
|
+
:param _func: Function name
|
|
233
261
|
:param args: Arguments
|
|
234
262
|
:param kwargs: Keyword arguments
|
|
235
263
|
:return:
|
|
@@ -240,11 +268,19 @@ class RemoteController(Controller):
|
|
|
240
268
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
241
269
|
current_action_id = tctx.action
|
|
242
270
|
|
|
271
|
+
func_name = _func.__name__
|
|
272
|
+
invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
|
|
243
273
|
inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
|
|
244
|
-
|
|
274
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
275
|
+
|
|
276
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
277
|
+
tctx, func_name, serialized_inputs, invoke_seq_num
|
|
278
|
+
)
|
|
245
279
|
|
|
246
280
|
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
247
|
-
await upload_inputs_with_retry(
|
|
281
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
282
|
+
# Clear to free memory
|
|
283
|
+
serialized_inputs = None # type: ignore
|
|
248
284
|
|
|
249
285
|
prev_action = await self.get_action(
|
|
250
286
|
run_definition_pb2.ActionIdentifier(
|
|
@@ -310,12 +346,20 @@ class RemoteController(Controller):
|
|
|
310
346
|
current_action_id = tctx.action
|
|
311
347
|
task_name = _task.spec.task_template.id.name
|
|
312
348
|
|
|
349
|
+
invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)
|
|
350
|
+
|
|
313
351
|
native_interface = types.guess_interface(_task.spec.task_template.interface)
|
|
314
352
|
inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
|
|
315
|
-
|
|
353
|
+
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
354
|
+
|
|
355
|
+
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
356
|
+
tctx, task_name, serialized_inputs, invoke_seq_num
|
|
357
|
+
)
|
|
316
358
|
|
|
317
359
|
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
318
|
-
await upload_inputs_with_retry(
|
|
360
|
+
await upload_inputs_with_retry(serialized_inputs, inputs_uri)
|
|
361
|
+
# Clear to free memory
|
|
362
|
+
serialized_inputs = None # type: ignore
|
|
319
363
|
|
|
320
364
|
action = Action.from_task(
|
|
321
365
|
sub_action_id=run_definition_pb2.ActionIdentifier(
|
|
@@ -32,7 +32,7 @@ class Controller:
|
|
|
32
32
|
max_system_retries: int = 5,
|
|
33
33
|
resource_log_interval_sec: float = 10.0,
|
|
34
34
|
min_backoff_on_err_sec: float = 0.1,
|
|
35
|
-
thread_wait_timeout_sec: float =
|
|
35
|
+
thread_wait_timeout_sec: float = 0.5,
|
|
36
36
|
enqueue_timeout_sec: float = 5.0,
|
|
37
37
|
):
|
|
38
38
|
"""
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
3
5
|
from dataclasses import dataclass
|
|
4
6
|
from typing import Any, Dict, Tuple, Union
|
|
5
7
|
|
|
@@ -7,9 +9,8 @@ from flyteidl.core import execution_pb2, literals_pb2
|
|
|
7
9
|
|
|
8
10
|
import flyte.errors
|
|
9
11
|
import flyte.storage as storage
|
|
10
|
-
from flyte.
|
|
11
|
-
from flyte.
|
|
12
|
-
from flyte._protos.workflow import run_definition_pb2
|
|
12
|
+
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
13
|
+
from flyte.models import ActionID, NativeInterface, TaskContext
|
|
13
14
|
from flyte.types import TypeEngine
|
|
14
15
|
|
|
15
16
|
|
|
@@ -185,21 +186,46 @@ def convert_from_native_to_error(err: BaseException) -> Error:
|
|
|
185
186
|
)
|
|
186
187
|
|
|
187
188
|
|
|
188
|
-
def
|
|
189
|
+
def hash_data(data: Union[str, bytes]) -> str:
|
|
190
|
+
"""
|
|
191
|
+
Generate a hash for the given data. If the data is a string, it will be encoded to bytes before hashing.
|
|
192
|
+
:param data: The data to hash, can be a string or bytes.
|
|
193
|
+
:return: A hexadecimal string representation of the hash.
|
|
194
|
+
"""
|
|
195
|
+
if isinstance(data, str):
|
|
196
|
+
data = data.encode("utf-8")
|
|
197
|
+
digest = hashlib.sha256(data).digest()
|
|
198
|
+
return base64.b64encode(digest).decode("utf-8")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def generate_sub_action_id_and_output_path(
|
|
202
|
+
tctx: TaskContext,
|
|
203
|
+
task_spec_or_name: task_definition_pb2.TaskSpec | str,
|
|
204
|
+
serialized_inputs: str | bytes,
|
|
205
|
+
invoke_seq: int,
|
|
206
|
+
) -> Tuple[ActionID, str]:
|
|
189
207
|
"""
|
|
190
208
|
Generate a sub-action ID and output path based on the current task context, task name, and inputs.
|
|
209
|
+
|
|
210
|
+
action name = current action name + task name + input hash + group name (if available)
|
|
191
211
|
:param tctx:
|
|
192
|
-
:param
|
|
193
|
-
:param
|
|
212
|
+
:param task_spec_or_name: task specification or task name. Task name is only used in case of trace actions.
|
|
213
|
+
:param serialized_inputs:
|
|
214
|
+
:param invoke_seq: The sequence number of the invocation, used to differentiate between multiple invocations.
|
|
194
215
|
:return:
|
|
195
216
|
"""
|
|
196
217
|
current_action_id = tctx.action
|
|
197
218
|
current_output_path = tctx.run_base_dir
|
|
198
|
-
inputs_hash =
|
|
219
|
+
inputs_hash = hash_data(serialized_inputs)
|
|
220
|
+
if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
|
|
221
|
+
task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
|
|
222
|
+
else:
|
|
223
|
+
task_hash = task_spec_or_name
|
|
199
224
|
sub_action_id = current_action_id.new_sub_action_from(
|
|
200
|
-
|
|
225
|
+
task_hash=task_hash,
|
|
201
226
|
input_hash=inputs_hash,
|
|
202
227
|
group=tctx.group_data.name if tctx.group_data else None,
|
|
228
|
+
task_call_seq=invoke_seq,
|
|
203
229
|
)
|
|
204
230
|
sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
|
|
205
231
|
return sub_action_id, sub_run_output_path
|
|
@@ -3,11 +3,11 @@ from typing import List, Optional, Tuple
|
|
|
3
3
|
import flyte.errors
|
|
4
4
|
from flyte._code_bundle import download_bundle
|
|
5
5
|
from flyte._context import contextual_run
|
|
6
|
-
from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
7
6
|
from flyte._internal import Controller
|
|
8
7
|
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
9
8
|
from flyte._logging import log, logger
|
|
10
9
|
from flyte._task import TaskTemplate
|
|
10
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
11
11
|
|
|
12
12
|
from .convert import Error, Inputs, Outputs
|
|
13
13
|
from .task_serde import load_task
|
flyte/_internal/runtime/io.py
CHANGED
|
@@ -21,11 +21,11 @@ _OUTPUTS_FILE_NAME = "outputs.pb"
|
|
|
21
21
|
_CHECKPOINT_FILE_NAME = "_flytecheckpoints"
|
|
22
22
|
_ERROR_FILE_NAME = "error.pb"
|
|
23
23
|
_REPORT_FILE_NAME = "report.html"
|
|
24
|
-
|
|
24
|
+
_PKL_EXT = ".pkl.gz"
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def pkl_path(base_path: str) -> str:
|
|
28
|
-
return storage.join(base_path,
|
|
27
|
+
def pkl_path(base_path: str, pkl_name: str) -> str:
|
|
28
|
+
return storage.join(base_path, f"{pkl_name}{_PKL_EXT}")
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def inputs_path(base_path: str) -> str:
|
|
@@ -12,11 +12,11 @@ from google.protobuf import duration_pb2, wrappers_pb2
|
|
|
12
12
|
|
|
13
13
|
import flyte.errors
|
|
14
14
|
from flyte._cache.cache import VersionParameters, cache_from_request
|
|
15
|
-
from flyte._datastructures import SerializationContext
|
|
16
15
|
from flyte._logging import logger
|
|
17
16
|
from flyte._protos.workflow import task_definition_pb2
|
|
18
17
|
from flyte._secret import SecretRequest, secrets_from_request
|
|
19
18
|
from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
19
|
+
from flyte.models import CodeBundle, SerializationContext
|
|
20
20
|
|
|
21
21
|
from ..._retry import RetryStrategy
|
|
22
22
|
from ..._timeout import TimeoutType, timeout_from_request
|
|
@@ -208,3 +208,33 @@ def _get_urun_container(
|
|
|
208
208
|
data_config=task_template.data_loading_config(serialize_context),
|
|
209
209
|
config=task_template.config(serialize_context),
|
|
210
210
|
)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
|
|
214
|
+
"""
|
|
215
|
+
Extract the code bundle from the task spec.
|
|
216
|
+
:param task_spec: The task spec to extract the code bundle from.
|
|
217
|
+
:return: The extracted code bundle or None if not present.
|
|
218
|
+
"""
|
|
219
|
+
container = task_spec.task_template.container
|
|
220
|
+
if container and container.args:
|
|
221
|
+
pkl_path = None
|
|
222
|
+
tgz_path = None
|
|
223
|
+
dest_path: str = "."
|
|
224
|
+
version = ""
|
|
225
|
+
for i, v in enumerate(container.args):
|
|
226
|
+
if v == "--pkl":
|
|
227
|
+
# Extract the code bundle path from the argument
|
|
228
|
+
pkl_path = container.args[i + 1] if i + 1 < len(container.args) else None
|
|
229
|
+
elif v == "--tgz":
|
|
230
|
+
# Extract the code bundle path from the argument
|
|
231
|
+
tgz_path = container.args[i + 1] if i + 1 < len(container.args) else None
|
|
232
|
+
elif v == "--dest":
|
|
233
|
+
# Extract the destination path from the argument
|
|
234
|
+
dest_path = container.args[i + 1] if i + 1 < len(container.args) else "."
|
|
235
|
+
elif v == "--version":
|
|
236
|
+
# Extract the version from the argument
|
|
237
|
+
version = container.args[i + 1] if i + 1 < len(container.args) else ""
|
|
238
|
+
if pkl_path or tgz_path:
|
|
239
|
+
return CodeBundle(destination=dest_path, tgz=tgz_path, pkl=pkl_path, computed_version=version)
|
|
240
|
+
return None
|
|
@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
|
8
8
|
|
|
9
9
|
import flyte.report
|
|
10
10
|
from flyte._context import internal_ctx
|
|
11
|
-
from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath, TaskContext
|
|
12
11
|
from flyte._internal.imagebuild.image_builder import ImageCache
|
|
13
12
|
from flyte._logging import log, logger
|
|
14
13
|
from flyte._task import TaskTemplate
|
|
15
14
|
from flyte.errors import CustomError, RuntimeSystemError, RuntimeUnknownError, RuntimeUserError
|
|
15
|
+
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, TaskContext
|
|
16
16
|
|
|
17
17
|
from .. import Controller
|
|
18
18
|
from .convert import (
|
flyte/_run.py
CHANGED
|
@@ -4,17 +4,12 @@ import pathlib
|
|
|
4
4
|
import uuid
|
|
5
5
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
|
|
6
6
|
|
|
7
|
-
import
|
|
8
|
-
import flyte.report
|
|
9
|
-
from flyte import S3
|
|
7
|
+
from flyte.errors import InitializationError
|
|
10
8
|
|
|
11
9
|
from ._api_commons import syncer
|
|
12
10
|
from ._context import contextual_run, internal_ctx
|
|
13
|
-
from ._datastructures import ActionID, Checkpoints, RawDataPath, SerializationContext, TaskContext
|
|
14
11
|
from ._environment import Environment
|
|
15
12
|
from ._initialize import (
|
|
16
|
-
ABFS,
|
|
17
|
-
GCS,
|
|
18
13
|
_get_init_config,
|
|
19
14
|
get_client,
|
|
20
15
|
get_common_config,
|
|
@@ -22,14 +17,10 @@ from ._initialize import (
|
|
|
22
17
|
requires_initialization,
|
|
23
18
|
requires_storage,
|
|
24
19
|
)
|
|
25
|
-
from ._internal import create_controller
|
|
26
|
-
from ._internal.runtime.io import _CHECKPOINT_FILE_NAME
|
|
27
|
-
from ._internal.runtime.taskrunner import run_task
|
|
28
20
|
from ._logging import logger
|
|
29
|
-
from ._protos.common import identifier_pb2
|
|
30
21
|
from ._task import P, R, TaskTemplate
|
|
31
22
|
from ._tools import ipython_check
|
|
32
|
-
from .
|
|
23
|
+
from .models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
|
|
33
24
|
|
|
34
25
|
if TYPE_CHECKING:
|
|
35
26
|
from flyte.remote import Run
|
|
@@ -39,6 +30,22 @@ if TYPE_CHECKING:
|
|
|
39
30
|
Mode = Literal["local", "remote", "hybrid"]
|
|
40
31
|
|
|
41
32
|
|
|
33
|
+
async def _get_code_bundle_for_run(name: str) -> CodeBundle | None:
|
|
34
|
+
"""
|
|
35
|
+
Get the code bundle for the run with the given name.
|
|
36
|
+
This is used to get the code bundle for the run when running in hybrid mode.
|
|
37
|
+
"""
|
|
38
|
+
from flyte._internal.runtime.task_serde import extract_code_bundle
|
|
39
|
+
from flyte.remote import Run
|
|
40
|
+
|
|
41
|
+
run = await Run.get.aio(Run, name=name)
|
|
42
|
+
if run:
|
|
43
|
+
run_details = await run.details()
|
|
44
|
+
spec = run_details.action_details.pb2.resolved_task_spec
|
|
45
|
+
return extract_code_bundle(spec)
|
|
46
|
+
return None
|
|
47
|
+
|
|
48
|
+
|
|
42
49
|
@syncer.wrap
|
|
43
50
|
class _Runner:
|
|
44
51
|
def __init__(
|
|
@@ -81,6 +88,7 @@ class _Runner:
|
|
|
81
88
|
from ._deploy import build_images, plan_deploy
|
|
82
89
|
from ._internal.runtime.convert import convert_from_native_to_inputs
|
|
83
90
|
from ._internal.runtime.task_serde import translate_task_to_wire
|
|
91
|
+
from ._protos.common import identifier_pb2
|
|
84
92
|
from ._protos.workflow import run_definition_pb2, run_service_pb2
|
|
85
93
|
|
|
86
94
|
cfg = get_common_config()
|
|
@@ -187,9 +195,14 @@ class _Runner:
|
|
|
187
195
|
run in the cluster remotely. This is currently only used for testing,
|
|
188
196
|
over the longer term we will productize this.
|
|
189
197
|
"""
|
|
198
|
+
import flyte.report
|
|
190
199
|
from flyte._code_bundle import build_code_bundle, build_pkl_bundle
|
|
191
|
-
from flyte._datastructures import RawDataPath
|
|
192
200
|
from flyte._deploy import build_images, plan_deploy
|
|
201
|
+
from flyte.models import RawDataPath
|
|
202
|
+
from flyte.storage import ABFS, GCS, S3
|
|
203
|
+
|
|
204
|
+
from ._internal import create_controller
|
|
205
|
+
from ._internal.runtime.taskrunner import run_task
|
|
193
206
|
|
|
194
207
|
cfg = get_common_config()
|
|
195
208
|
|
|
@@ -199,17 +212,23 @@ class _Runner:
|
|
|
199
212
|
deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
|
|
200
213
|
image_cache = await build_images(deploy_plan)
|
|
201
214
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
215
|
+
code_bundle = None
|
|
216
|
+
if self._name is not None:
|
|
217
|
+
# Check if remote run service has this run name already and if exists, then extract the code bundle from it.
|
|
218
|
+
code_bundle = await _get_code_bundle_for_run(name=self._name)
|
|
219
|
+
|
|
220
|
+
if not code_bundle:
|
|
221
|
+
if self._interactive_mode:
|
|
222
|
+
code_bundle = await build_pkl_bundle(
|
|
223
|
+
obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
|
|
210
224
|
)
|
|
211
225
|
else:
|
|
212
|
-
|
|
226
|
+
if self._copy_files != "none":
|
|
227
|
+
code_bundle = await build_code_bundle(
|
|
228
|
+
from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
|
|
229
|
+
)
|
|
230
|
+
else:
|
|
231
|
+
code_bundle = None
|
|
213
232
|
|
|
214
233
|
version = self._version or (
|
|
215
234
|
code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
|
|
@@ -217,14 +236,14 @@ class _Runner:
|
|
|
217
236
|
if not version:
|
|
218
237
|
raise ValueError("Version is required when running a task")
|
|
219
238
|
|
|
220
|
-
project = cfg.project
|
|
221
|
-
domain = cfg.domain
|
|
222
|
-
org = cfg.org
|
|
239
|
+
project = cfg.project
|
|
240
|
+
domain = cfg.domain
|
|
241
|
+
org = cfg.org
|
|
223
242
|
action_name = "a0"
|
|
224
243
|
run_name = self._name
|
|
225
244
|
random_id = str(uuid.uuid4())[:6]
|
|
226
245
|
|
|
227
|
-
controller = create_controller(
|
|
246
|
+
controller = create_controller("remote", endpoint="localhost:8090", insecure=True)
|
|
228
247
|
action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org)
|
|
229
248
|
|
|
230
249
|
inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs)
|
|
@@ -242,7 +261,7 @@ class _Runner:
|
|
|
242
261
|
output_path = self._run_base_dir
|
|
243
262
|
raw_data_path = f"{output_path}/rd/{random_id}"
|
|
244
263
|
raw_data_path_obj = RawDataPath(path=raw_data_path)
|
|
245
|
-
checkpoint_path = f"{raw_data_path}/
|
|
264
|
+
checkpoint_path = f"{raw_data_path}/checkpoint"
|
|
246
265
|
prev_checkpoint = f"{raw_data_path}/prev_checkpoint"
|
|
247
266
|
checkpoints = Checkpoints(checkpoint_path, prev_checkpoint)
|
|
248
267
|
|
|
@@ -253,7 +272,7 @@ class _Runner:
|
|
|
253
272
|
checkpoints=checkpoints,
|
|
254
273
|
code_bundle=code_bundle,
|
|
255
274
|
output_path=output_path,
|
|
256
|
-
version=version,
|
|
275
|
+
version=version if version else "na",
|
|
257
276
|
raw_data_path=raw_data_path_obj,
|
|
258
277
|
compiled_image_cache=image_cache,
|
|
259
278
|
run_base_dir=self._run_base_dir,
|
|
@@ -276,7 +295,7 @@ class _Runner:
|
|
|
276
295
|
)
|
|
277
296
|
from flyte._internal.runtime.entrypoints import direct_dispatch
|
|
278
297
|
|
|
279
|
-
controller = create_controller(
|
|
298
|
+
controller = create_controller("local")
|
|
280
299
|
inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
|
|
281
300
|
if self._name is None:
|
|
282
301
|
action = ActionID.create_random()
|
flyte/_task.py
CHANGED
|
@@ -25,7 +25,6 @@ from flyte.errors import RuntimeSystemError, RuntimeUserError
|
|
|
25
25
|
|
|
26
26
|
from ._cache import Cache, CacheRequest
|
|
27
27
|
from ._context import internal_ctx
|
|
28
|
-
from ._datastructures import NativeInterface, SerializationContext
|
|
29
28
|
from ._doc import Documentation
|
|
30
29
|
from ._image import Image
|
|
31
30
|
from ._resources import Resources
|
|
@@ -33,6 +32,7 @@ from ._retry import RetryStrategy
|
|
|
33
32
|
from ._reusable_environment import ReusePolicy
|
|
34
33
|
from ._secret import SecretRequest
|
|
35
34
|
from ._timeout import TimeoutType
|
|
35
|
+
from .models import NativeInterface, SerializationContext
|
|
36
36
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from kubernetes.client import V1PodTemplate
|
|
@@ -224,7 +224,7 @@ class TaskTemplate(Generic[P, R]):
|
|
|
224
224
|
# We will also check if we are not initialized, It is not expected to be not initialized
|
|
225
225
|
from ._internal.controllers import get_controller
|
|
226
226
|
|
|
227
|
-
controller =
|
|
227
|
+
controller = get_controller()
|
|
228
228
|
if controller:
|
|
229
229
|
return await controller.submit(self, *args, **kwargs)
|
|
230
230
|
return await self.execute(*args, **kwargs)
|
flyte/_task_environment.py
CHANGED
|
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Literal, Opti
|
|
|
10
10
|
import rich.repr
|
|
11
11
|
|
|
12
12
|
from ._cache import CacheRequest
|
|
13
|
-
from ._datastructures import NativeInterface
|
|
14
13
|
from ._doc import Documentation
|
|
15
14
|
from ._environment import Environment
|
|
16
15
|
from ._image import Image
|
|
@@ -19,6 +18,7 @@ from ._retry import RetryStrategy
|
|
|
19
18
|
from ._reusable_environment import ReusePolicy
|
|
20
19
|
from ._secret import SecretRequest
|
|
21
20
|
from ._task import AsyncFunctionTaskTemplate, TaskTemplate
|
|
21
|
+
from .models import NativeInterface
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
24
|
from kubernetes.client import V1PodTemplate
|
flyte/_trace.py
CHANGED
|
@@ -4,7 +4,7 @@ import time
|
|
|
4
4
|
from datetime import timedelta
|
|
5
5
|
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
|
|
6
6
|
|
|
7
|
-
from flyte.
|
|
7
|
+
from flyte.models import NativeInterface
|
|
8
8
|
|
|
9
9
|
T = TypeVar("T")
|
|
10
10
|
|
|
@@ -14,7 +14,6 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
|
|
|
14
14
|
A decorator that traces function execution with timing information.
|
|
15
15
|
Works with regular functions, async functions, and async generators/iterators.
|
|
16
16
|
"""
|
|
17
|
-
func_name = func.__name__
|
|
18
17
|
|
|
19
18
|
@functools.wraps(func)
|
|
20
19
|
def wrapper_sync(*args: Any, **kwargs: Any) -> Any:
|
|
@@ -31,9 +30,9 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
|
|
|
31
30
|
# We will also check if we are not initialized, It is not expected to be not initialized
|
|
32
31
|
from ._internal.controllers import get_controller
|
|
33
32
|
|
|
34
|
-
controller =
|
|
33
|
+
controller = get_controller()
|
|
35
34
|
iface = NativeInterface.from_callable(func)
|
|
36
|
-
info, ok = await controller.get_action_outputs(iface,
|
|
35
|
+
info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
|
|
37
36
|
if ok:
|
|
38
37
|
if info.output:
|
|
39
38
|
return info.output
|
|
@@ -74,9 +73,9 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
|
|
|
74
73
|
# We will also check if we are not initialized, It is not expected to be not initialized
|
|
75
74
|
from ._internal.controllers import get_controller
|
|
76
75
|
|
|
77
|
-
controller =
|
|
76
|
+
controller = get_controller()
|
|
78
77
|
iface = NativeInterface.from_callable(func)
|
|
79
|
-
info, ok = await controller.get_action_outputs(iface,
|
|
78
|
+
info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
|
|
80
79
|
if ok:
|
|
81
80
|
if info.output:
|
|
82
81
|
for item in info.output:
|
flyte/_utils/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ Internal utility functions.
|
|
|
4
4
|
Except for logging, modules in this package should not depend on any other part of the repo.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from .async_cache import AsyncLRUCache
|
|
7
8
|
from .coro_management import run_coros
|
|
8
9
|
from .file_handling import filehash_update, update_hasher_for_source
|
|
9
10
|
from .helpers import get_cwd_editable_install
|
|
@@ -11,6 +12,7 @@ from .lazy_module import lazy_module
|
|
|
11
12
|
from .uv_script_parser import parse_uv_script_file
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
15
|
+
"AsyncLRUCache",
|
|
14
16
|
"filehash_update",
|
|
15
17
|
"get_cwd_editable_install",
|
|
16
18
|
"lazy_module",
|