flyte 2.0.0b18__py3-none-any.whl → 2.0.0b20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic.
- flyte/_bin/runtime.py +2 -1
- flyte/_initialize.py +4 -4
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +5 -5
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_controller.py +19 -23
- flyte/_internal/controllers/remote/_core.py +120 -92
- flyte/_internal/controllers/remote/_informer.py +15 -6
- flyte/_map.py +90 -12
- flyte/_task.py +3 -0
- flyte/_version.py +3 -3
- flyte/cli/_create.py +4 -1
- flyte/cli/_deploy.py +4 -5
- flyte/cli/_params.py +18 -4
- flyte/cli/_run.py +2 -2
- flyte/config/_config.py +2 -2
- flyte/config/_reader.py +14 -8
- flyte/errors.py +12 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +17 -0
- flyte/io/_dataframe/basic_dfs.py +16 -7
- flyte/io/_dataframe/dataframe.py +84 -123
- flyte/remote/_task.py +52 -22
- flyte/report/_report.py +1 -1
- flyte/types/_type_engine.py +1 -30
- {flyte-2.0.0b18.data → flyte-2.0.0b20.data}/scripts/runtime.py +2 -1
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/METADATA +2 -1
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/RECORD +33 -31
- {flyte-2.0.0b18.data → flyte-2.0.0b20.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/top_level.txt +0 -0
flyte/_bin/runtime.py
CHANGED
@@ -101,7 +101,6 @@ def main(
     from flyte._logging import logger
     from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath

-    logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
     logger.info("Registering faulthandler for SIGUSR1")
     faulthandler.register(signal.SIGUSR1)

@@ -117,6 +116,8 @@ def main(
     if name.startswith("{{"):
         name = os.getenv("ACTION_NAME", "")

+    logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
+
     if debug and name == "a0":
         from flyte._debug.vscode import _start_vscode_server

flyte/_initialize.py
CHANGED
@@ -228,7 +228,7 @@ async def init(

 @syncify
 async def init_from_config(
-    path_or_config: str | Config | None = None,
+    path_or_config: str | Path | Config | None = None,
     root_dir: Path | None = None,
     log_level: int | None = None,
 ) -> None:
@@ -251,11 +251,11 @@ async def init_from_config(
     if path_or_config is None:
         # If no path is provided, use the default config file
         cfg = config.auto()
-    elif isinstance(path_or_config, str):
+    elif isinstance(path_or_config, (str, Path)):
         if root_dir:
-            cfg_path = root_dir / path_or_config
+            cfg_path = root_dir.expanduser() / path_or_config
         else:
-            cfg_path = path_or_config
+            cfg_path = Path(path_or_config).expanduser()
         if not Path(cfg_path).exists():
             raise InitializationError(
                 "ConfigFileNotFoundError",
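With this change `init_from_config` accepts a `pathlib.Path` in addition to a plain string, and both are run through `expanduser()`. A minimal usage sketch, assuming a local `config.yaml` (the file name is illustrative):

```python
from pathlib import Path

import flyte

# Both spellings should now be equivalent; strings are wrapped in
# Path(...).expanduser(), so "~" prefixes also resolve.
flyte.init_from_config("config.yaml")
flyte.init_from_config(Path("~/.flyte/config.yaml"))
```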
flyte/_internal/controllers/__init__.py
CHANGED

@@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple
 from flyte._task import TaskTemplate
 from flyte.models import ActionID, NativeInterface

+if TYPE_CHECKING:
+    from flyte.remote._task import TaskDetails
+
 from ._trace import TraceInfo

 __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "get_controller"]

-from ..._protos.workflow import task_definition_pb2
-
 if TYPE_CHECKING:
     import concurrent.futures

@@ -41,9 +42,7 @@ class Controller(Protocol):
         """
         ...

-    async def submit_task_ref(
-        self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-    ) -> Any:
+    async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
         """
         Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
         the current coroutine until the result is available.
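The `if TYPE_CHECKING:` guard plus the quoted annotation is the standard idiom for keeping a type hint without importing the module at runtime, which is what avoids a runtime import cycle with `flyte.remote._task` here. A generic sketch of the pattern:

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no runtime import cycle.
    from flyte.remote._task import TaskDetails


async def submit_task_ref(_task: "TaskDetails", *args: Any, **kwargs: Any) -> Any:
    # The string annotation is resolved lazily against the guarded import.
    ...
```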
flyte/_internal/controllers/_local_controller.py
CHANGED

@@ -11,10 +11,10 @@ from flyte._internal.controllers import TraceInfo
 from flyte._internal.runtime import convert
 from flyte._internal.runtime.entrypoints import direct_dispatch
 from flyte._logging import log, logger
-from flyte._protos.workflow import task_definition_pb2
 from flyte._task import TaskTemplate
 from flyte._utils.helpers import _selector_policy
 from flyte.models import ActionID, NativeInterface
+from flyte.remote._task import TaskDetails

 R = TypeVar("R")

@@ -192,7 +192,7 @@ class LocalController:
         assert info.start_time
         assert info.end_time

-    async def submit_task_ref(
-        self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-    ) -> Any:
+    async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *args, **kwargs) -> Any:
+        raise flyte.errors.ReferenceTaskError(
+            f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
+        )
flyte/_internal/controllers/remote/_controller.py
CHANGED

@@ -12,7 +12,6 @@ from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
 import flyte
 import flyte.errors
 import flyte.storage as storage
-import flyte.types as types
 from flyte._code_bundle import build_pkl_bundle
 from flyte._context import internal_ctx
 from flyte._internal.controllers import TraceInfo
@@ -24,10 +23,11 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
 from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
 from flyte._logging import logger
 from flyte._protos.common import identifier_pb2
-from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
+from flyte._protos.workflow import run_definition_pb2
 from flyte._task import TaskTemplate
 from flyte._utils.helpers import _selector_policy
 from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
+from flyte.remote._task import TaskDetails
@@ -117,9 +117,8 @@ class RemoteController(Controller):
     def __init__(
         self,
         client_coro: Awaitable[ClientSet],
-        workers: int,
-        max_system_retries: int,
-        default_parent_concurrency: int = 100,
+        workers: int = 20,
+        max_system_retries: int = 10,
     ):
         """ """
         super().__init__(
@@ -127,6 +126,7 @@ class RemoteController(Controller):
             workers=workers,
             max_system_retries=max_system_retries,
         )
+        default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
         self._default_parent_concurrency = default_parent_concurrency
         self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
             lambda: asyncio.Semaphore(default_parent_concurrency)
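The parent-level fan-out cap is now fixed at 100 unless the `_F_P_CNC` environment variable overrides it, and is enforced with one lazily created semaphore per parent action. A standalone sketch of the pattern (function name hypothetical):

```python
import asyncio
import os
from collections import defaultdict
from typing import DefaultDict

limit = int(os.getenv("_F_P_CNC", "100"))

# defaultdict creates the semaphore for a given parent on first access.
semaphores: DefaultDict[str, asyncio.Semaphore] = defaultdict(lambda: asyncio.Semaphore(limit))

async def submit_child(parent_action: str) -> None:
    # At most `limit` children of the same parent are in flight at a time.
    async with semaphores[parent_action]:
        await asyncio.sleep(0.01)  # stand-in for the real submit-and-wait work
```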
@@ -482,19 +482,17 @@
            # If the action is cancelled, we need to cancel the action on the server as well
            raise

-    async def _submit_task_ref(
-        self, invoke_seq_num: int, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-    ) -> Any:
+    async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, **kwargs) -> Any:
         ctx = internal_ctx()
         tctx = ctx.data.task_context
         if tctx is None:
             raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
         current_action_id = tctx.action
-        task_name = _task.
+        task_name = _task.name
+
+        native_interface = _task.interface
+        pb_interface = _task.pb2.spec.task_template.interface

-        native_interface = types.guess_interface(
-            _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
-        )
         inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
         inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
         sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
@@ -503,19 +501,19 @@

         serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
         inputs_uri = io.inputs_path(sub_action_output_path)
-        await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_inline_io_bytes)
+        await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes)
         # cache key - task name, task signature, inputs, cache version
         cache_key = None
-        md = _task.spec.task_template.metadata
+        md = _task.pb2.spec.task_template.metadata
         ignored_input_vars = []
         if len(md.cache_ignore_input_vars) > 0:
             ignored_input_vars = list(md.cache_ignore_input_vars)
-        if
-            discovery_version =
+        if md and md.discoverable:
+            discovery_version = md.discovery_version
             cache_key = convert.generate_cache_key_hash(
                 task_name,
                 inputs_hash,
-                _task.spec.task_template.interface,
+                pb_interface,
                 discovery_version,
                 ignored_input_vars,
                 inputs.proto_inputs,
@@ -537,7 +535,7 @@
             ),
             parent_action_name=current_action_id.name,
             group_data=tctx.group_data,
-            task_spec=_task.spec,
+            task_spec=_task.pb2.spec,
             inputs_uri=inputs_uri,
             run_output_base=tctx.run_base_dir,
             cache_key=cache_key,
@@ -566,12 +564,10 @@
                 "RuntimeError",
                 f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
             )
-            return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, max_inline_io_bytes)
+            return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes)
         return None

-    async def submit_task_ref(
-        self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-    ) -> Any:
+    async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
         ctx = internal_ctx()
         tctx = ctx.data.task_context
         if tctx is None:
@@ -579,4 +575,4 @@
         current_action_id = tctx.action
         task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
         async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
-            return await self._submit_task_ref(task_call_seq, _task, max_inline_io_bytes, *args, **kwargs)
+            return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)
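Most of these edits swap direct proto access (`_task.spec...`) for the new `flyte.remote._task.TaskDetails` wrapper, which carries the raw message plus derived conveniences (`name`, `interface`, `max_inline_io_bytes`), so callers no longer thread `max_inline_io_bytes` through every call. A rough sketch of the wrapper's shape as inferred from this diff, not the actual class:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class TaskDetailsSketch:
    # Hypothetical stand-in; the real TaskDetails lives in flyte/remote/_task.py.
    pb2: Any                  # underlying task_definition_pb2.TaskDetails message
    interface: Any            # native interface derived from the typed proto interface
    max_inline_io_bytes: int  # IO inlining limit, now owned by the task itself

    @property
    def name(self) -> str:
        return self.pb2.spec.task_template.id.name
```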
flyte/_internal/controllers/remote/_core.py
CHANGED

@@ -1,12 +1,14 @@
 from __future__ import annotations

 import asyncio
+import os
 import sys
 import threading
 from asyncio import Event
 from typing import Awaitable, Coroutine, Optional

 import grpc.aio
+from aiolimiter import AsyncLimiter
 from google.protobuf.wrappers_pb2 import StringValue

 import flyte.errors
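`AsyncLimiter(max_rate, time_period)` from aiolimiter is a leaky-bucket rate limiter, so `AsyncLimiter(max_qps, 1.0)` caps the controller at roughly `max_qps` acquisitions per second. A self-contained sketch of the usage:

```python
import asyncio

from aiolimiter import AsyncLimiter

limiter = AsyncLimiter(100, 1.0)  # bucket allows ~100 acquisitions per second

async def guarded_rpc(i: int) -> None:
    async with limiter:  # suspends here once the bucket is drained
        print(f"request {i}")

async def main() -> None:
    # 300 calls finish in roughly 2-3 seconds instead of all at once.
    await asyncio.gather(*(guarded_rpc(i) for i in range(300)))

asyncio.run(main())
```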
@@ -32,10 +34,10 @@ class Controller:
     def __init__(
         self,
         client_coro: Awaitable[ClientSet],
-        workers: int =
-        max_system_retries: int =
+        workers: int = 20,
+        max_system_retries: int = 10,
         resource_log_interval_sec: float = 10.0,
-        min_backoff_on_err_sec: float = 0.
+        min_backoff_on_err_sec: float = 0.5,
         thread_wait_timeout_sec: float = 5.0,
         enqueue_timeout_sec: float = 5.0,
     ):
@@ -53,14 +55,17 @@
         self._running = False
         self._resource_log_task = None
         self._workers = workers
-        self._max_retries = max_system_retries
+        self._max_retries = int(os.getenv("_F_MAX_RETRIES", max_system_retries))
         self._resource_log_interval = resource_log_interval_sec
         self._min_backoff_on_err = min_backoff_on_err_sec
+        self._max_backoff_on_err = float(os.getenv("_F_MAX_BFF_ON_ERR", "10.0"))
         self._thread_wait_timeout = thread_wait_timeout_sec
         self._client_coro = client_coro
         self._failure_event: Event | None = None
         self._enqueue_timeout = enqueue_timeout_sec
         self._informer_start_wait_timeout = thread_wait_timeout_sec
+        max_qps = int(os.getenv("_F_MAX_QPS", "100"))
+        self._rate_limiter = AsyncLimiter(max_qps, 1.0)

         # Thread management
         self._thread = None
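Several of these knobs are now tunable through undocumented environment variables (`_F_MAX_RETRIES`, `_F_MAX_BFF_ON_ERR`, `_F_MAX_QPS`). Note that `os.getenv(name, default)` returns the default unchanged when the variable is unset, so the `int(...)` wrapper accepts either the int default or the string from the environment:

```python
import os

def int_from_env(name: str, default: int) -> int:
    # int(default) is a no-op when the env var is unset; otherwise the
    # environment string such as "25" is parsed.
    return int(os.getenv(name, default))

max_retries = int_from_env("_F_MAX_RETRIES", 10)
```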
@@ -194,15 +199,16 @@
         # We will wait for this to signal that the thread is ready
         # Signal the main thread that we're ready
         logger.debug("Background thread initialization complete")
-        self._thread_ready.set()
         if sys.version_info >= (3, 11):
             async with asyncio.TaskGroup() as tg:
                 for i in range(self._workers):
-                    tg.create_task(self._bg_run())
+                    tg.create_task(self._bg_run(f"worker-{i}"))
+                self._thread_ready.set()
         else:
             tasks = []
             for i in range(self._workers):
-                tasks.append(asyncio.create_task(self._bg_run()))
+                tasks.append(asyncio.create_task(self._bg_run(f"worker-{i}")))
+            self._thread_ready.set()
             await asyncio.gather(*tasks)

     def _bg_thread_target(self):
@@ -221,6 +227,7 @@
         except Exception as e:
             logger.error(f"Controller thread encountered an exception: {e}")
             self._set_exception(e)
+            self._failure_event.set()
         finally:
             if self._loop and self._loop.is_running():
                 self._loop.close()
@@ -292,21 +299,22 @@
        started = action.is_started()
        action.mark_cancelled()
        if started:
+            async with self._rate_limiter:
+                logger.info(f"Cancelling action: {action.name}")
+                try:
+                    # TODO add support when the queue service supports aborting actions
+                    # await self._queue_service.AbortQueuedAction(
+                    #     queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
+                    #     wait_for_ready=True,
+                    # )
+                    logger.info(f"Successfully cancelled action: {action.name}")
+                except grpc.aio.AioRpcError as e:
+                    if e.code() in [
+                        grpc.StatusCode.NOT_FOUND,
+                        grpc.StatusCode.FAILED_PRECONDITION,
+                    ]:
+                        logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
+                        return
        else:
            # If the action is not started, we have to ensure it does not get launched
            logger.info(f"Action {action.name} is not started, no need to cancel.")
@@ -320,56 +328,69 @@
         Attempt to launch an action.
         """
         if not action.is_started():
-            if action.
+            async with self._rate_limiter:
+                task: queue_service_pb2.TaskAction | None = None
+                trace: queue_service_pb2.TraceAction | None = None
+                if action.type == "task":
+                    if action.task is None:
+                        raise flyte.errors.RuntimeSystemError(
+                            "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+                        )
+                    cache_key = None
+                    logger.info(f"Action {action.name} has cache version {action.cache_key}")
+                    if action.cache_key:
+                        cache_key = StringValue(value=action.cache_key)
+
+                    task = queue_service_pb2.TaskAction(
+                        id=task_definition_pb2.TaskIdentifier(
+                            version=action.task.task_template.id.version,
+                            org=action.task.task_template.id.org,
+                            project=action.task.task_template.id.project,
+                            domain=action.task.task_template.id.domain,
+                            name=action.task.task_template.id.name,
+                        ),
+                        spec=action.task,
+                        cache_key=cache_key,
                     )
-            return
-            logger.exception(f"Failed to launch action: {action.name} backing off...")
-            logger.debug(f"Action details: {action}")
-            raise e
+                elif action.type == "trace":
+                    trace = action.trace
+
+                logger.debug(f"Attempting to launch action: {action.name}")
+                try:
+                    await self._queue_service.EnqueueAction(
+                        queue_service_pb2.EnqueueActionRequest(
+                            action_id=action.action_id,
+                            parent_action_name=action.parent_action_name,
+                            task=task,
+                            trace=trace,
+                            input_uri=action.inputs_uri,
+                            run_output_base=action.run_output_base,
+                            group=action.group.name if action.group else None,
+                            # Subject is not used in the current implementation
+                        ),
+                        wait_for_ready=True,
+                        timeout=self._enqueue_timeout,
+                    )
+                    logger.info(f"Successfully launched action: {action.name}")
+                except grpc.aio.AioRpcError as e:
+                    if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                        logger.info(f"Action {action.name} already exists, continuing to monitor.")
+                        return
+                    if e.code() in [
+                        grpc.StatusCode.FAILED_PRECONDITION,
+                        grpc.StatusCode.INVALID_ARGUMENT,
+                        grpc.StatusCode.NOT_FOUND,
+                    ]:
+                        raise flyte.errors.RuntimeSystemError(
+                            e.code().name, f"Precondition failed: {e.details()}"
+                        ) from e
+                    # For all other errors, we will retry with backoff
+                    logger.exception(
+                        f"Failed to launch action: {action.name}, Code: {e.code()}, "
+                        f"Details {e.details()} backing off..."
+                    )
+                    logger.debug(f"Action details: {action}")
+                    raise flyte.errors.SlowDownError(f"Failed to launch action: {e.details()}") from e

     @log
     async def _bg_process(self, action: Action):
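The launch path now triages gRPC failures: ALREADY_EXISTS is treated as success (the enqueue is idempotent), precondition-style codes are fatal, and everything else becomes a retryable SlowDownError. A condensed sketch of that triage (the helper itself is illustrative; the exception names come from the diff):

```python
import grpc

FATAL_CODES = {
    grpc.StatusCode.FAILED_PRECONDITION,
    grpc.StatusCode.INVALID_ARGUMENT,
    grpc.StatusCode.NOT_FOUND,
}

def classify_enqueue_error(code: grpc.StatusCode) -> str:
    if code == grpc.StatusCode.ALREADY_EXISTS:
        return "ok"     # idempotent enqueue: keep monitoring the action
    if code in FATAL_CODES:
        return "fatal"  # surfaced as RuntimeSystemError, no retry
    return "retry"      # surfaced as SlowDownError -> worker backs off

assert classify_enqueue_error(grpc.StatusCode.UNAVAILABLE) == "retry"
```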
@@ -397,35 +418,42 @@
             await asyncio.sleep(self._resource_log_interval)

     @log
-    async def _bg_run(self):
+    async def _bg_run(self, worker_id: str):
         """Run loop with resource status logging"""
+        logger.info(f"Worker {worker_id} started")
         while self._running:
             logger.debug(f"{threading.current_thread().name} Waiting for resource")
             action = await self._shared_queue.get()
             logger.debug(f"{threading.current_thread().name} Got resource {action.name}")
             try:
                 await self._bg_process(action)
-            except
-                # TODO we need a better way of handling backoffs currently the entire worker coroutine backs off
-                await asyncio.sleep(self._min_backoff_on_err)
-                action.increment_retries()
+            except flyte.errors.SlowDownError as e:
+                action.retries += 1
                 if action.retries > self._max_retries:
+                    raise
+                backoff = min(self._min_backoff_on_err * (2 ** (action.retries - 1)), self._max_backoff_on_err)
+                logger.warning(
+                    f"[{worker_id}] Backing off for {backoff} [retry {action.retries}/{self._max_retries}] "
+                    f"on action {action.name} due to error: {e}"
+                )
+                await asyncio.sleep(backoff)
+                logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
+                await self._shared_queue.put(action)
+            except Exception as e:
+                logger.error(f"[{worker_id}] Error in controller loop: {e}")
+                err = flyte.errors.RuntimeSystemError(
+                    code=type(e).__name__,
+                    message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
+                    worker=worker_id,
+                )
+                err.__cause__ = e
+                action.set_client_error(err)
+                informer = await self._informers.get(
+                    run_name=action.run_name,
+                    parent_action_name=action.parent_action_name,
+                )
+                if informer:
+                    await informer.fire_completion_event(action.name)
             finally:
                 self._shared_queue.task_done()

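The flat per-error sleep is replaced with capped exponential backoff: delay = min(min_backoff * 2^(retries - 1), max_backoff). With the defaults above (0.5 s minimum, 10 s cap via `_F_MAX_BFF_ON_ERR`) the retry delays for a SlowDownError are 0.5, 1, 2, 4, 8, 10, 10, ... seconds; a one-liner to verify:

```python
def backoff_schedule(min_backoff: float = 0.5, max_backoff: float = 10.0, retries: int = 8):
    # Same formula the worker loop applies per SlowDownError retry.
    return [min(min_backoff * (2 ** (n - 1)), max_backoff) for n in range(1, retries + 1)]

print(backoff_schedule())  # [0.5, 1.0, 2.0, 4.0, 8.0, 10.0, 10.0, 10.0]
```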
flyte/_internal/controllers/remote/_informer.py
CHANGED

@@ -132,8 +132,10 @@ class Informer:
         parent_action_name: str,
         shared_queue: Queue,
         client: Optional[StateService] = None,
+        min_watch_backoff: float = 1.0,
+        max_watch_backoff: float = 30.0,
         watch_conn_timeout_sec: float = 5.0,
+        max_watch_retries: int = 10,
     ):
         self.name = self.mkname(run_name=run_id.name, parent_action_name=parent_action_name)
         self.parent_action_name = parent_action_name
@@ -144,8 +146,10 @@ class Informer:
         self._running = False
         self._watch_task: asyncio.Task | None = None
         self._ready = asyncio.Event()
-        self.
+        self._min_watch_backoff = min_watch_backoff
+        self._max_watch_backoff = max_watch_backoff
         self._watch_conn_timeout_sec = watch_conn_timeout_sec
+        self._max_watch_retries = max_watch_retries

     @classmethod
     def mkname(cls, *, run_name: str, parent_action_name: str) -> str:
@@ -211,13 +215,16 @@
         """
         # sentinel = False
         retries = 0
-        max_retries = 5
         last_exc = None
         while self._running:
-            if retries >= max_retries:
-                logger.error(
+            if retries >= self._max_watch_retries:
+                logger.error(
+                    f"Informer watch failure retries crossed threshold {retries}/{self._max_watch_retries}, exiting!"
+                )
                 raise last_exc
             try:
+                if retries >= 1:
+                    logger.warning(f"Informer watch retrying, attempt {retries}/{self._max_watch_retries}")
                 watcher = self._client.Watch(
                     state_service_pb2.WatchRequest(
                         parent_action_id=identifier_pb2.ActionIdentifier(
@@ -252,7 +259,9 @@
                 logger.exception(f"Watch error: {self.name}", exc_info=e)
                 last_exc = e
                 retries += 1
+                backoff = min(self._min_watch_backoff * (2**retries), self._max_watch_backoff)
+                logger.warning(f"Watch for {self.name} failed, retrying in {backoff} seconds...")
+                await asyncio.sleep(backoff)

     @log
     async def start(self, timeout: Optional[float] = None) -> asyncio.Task:
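Pattern-wise, the watcher is now a bounded retry loop around a server-streaming RPC: count failures, back off exponentially, and re-raise the last exception once the budget is spent. A generic skeleton of that loop with no Flyte specifics (all names here are illustrative):

```python
import asyncio
from typing import AsyncIterator, Callable

async def watch_with_retries(
    connect: Callable[[], AsyncIterator],
    max_retries: int = 10,
    min_backoff: float = 1.0,
    max_backoff: float = 30.0,
) -> AsyncIterator:
    retries, last_exc = 0, None
    while True:
        if retries >= max_retries:
            raise last_exc  # retry budget exhausted: surface the last failure
        try:
            async for event in connect():  # re-open the stream on each attempt
                yield event
        except Exception as e:
            last_exc = e
            retries += 1
            await asyncio.sleep(min(min_backoff * (2 ** retries), max_backoff))
```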
|