flyte 2.0.0b19__py3-none-any.whl → 2.0.0b21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte might be problematic.

Files changed (37)
  1. flyte/_bin/runtime.py +2 -2
  2. flyte/_code_bundle/_ignore.py +11 -3
  3. flyte/_code_bundle/_packaging.py +9 -5
  4. flyte/_code_bundle/_utils.py +2 -2
  5. flyte/_initialize.py +14 -7
  6. flyte/_interface.py +35 -2
  7. flyte/_internal/controllers/remote/__init__.py +0 -2
  8. flyte/_internal/controllers/remote/_controller.py +3 -3
  9. flyte/_internal/controllers/remote/_core.py +120 -92
  10. flyte/_internal/controllers/remote/_informer.py +15 -6
  11. flyte/_internal/imagebuild/docker_builder.py +58 -8
  12. flyte/_internal/imagebuild/remote_builder.py +9 -26
  13. flyte/_keyring/__init__.py +0 -0
  14. flyte/_keyring/file.py +85 -0
  15. flyte/_logging.py +19 -8
  16. flyte/_utils/coro_management.py +2 -1
  17. flyte/_version.py +3 -3
  18. flyte/config/_config.py +6 -4
  19. flyte/config/_reader.py +19 -4
  20. flyte/errors.py +9 -0
  21. flyte/git/_config.py +2 -0
  22. flyte/io/_dataframe/dataframe.py +3 -2
  23. flyte/io/_dir.py +72 -72
  24. flyte/models.py +6 -2
  25. flyte/remote/_client/auth/_authenticators/device_code.py +3 -4
  26. flyte/remote/_data.py +2 -3
  27. flyte/remote/_run.py +17 -1
  28. flyte/storage/_config.py +5 -1
  29. flyte/types/_type_engine.py +7 -0
  30. {flyte-2.0.0b19.data → flyte-2.0.0b21.data}/scripts/runtime.py +2 -2
  31. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/METADATA +2 -1
  32. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/RECORD +37 -35
  33. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/entry_points.txt +3 -0
  34. {flyte-2.0.0b19.data → flyte-2.0.0b21.data}/scripts/debug.py +0 -0
  35. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/WHEEL +0 -0
  36. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/licenses/LICENSE +0 -0
  37. {flyte-2.0.0b19.dist-info → flyte-2.0.0b21.dist-info}/top_level.txt +0 -0
flyte/_bin/runtime.py CHANGED
@@ -21,8 +21,8 @@ import click
 
 ACTION_NAME = "ACTION_NAME"
 RUN_NAME = "RUN_NAME"
-PROJECT_NAME = "FLYTE_INTERNAL_TASK_PROJECT"
-DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
+PROJECT_NAME = "FLYTE_INTERNAL_EXECUTION_PROJECT"
+DOMAIN_NAME = "FLYTE_INTERNAL_EXECUTION_DOMAIN"
 ORG_NAME = "_U_ORG_NAME"
 ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
 RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
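Note: with this rename the runtime resolves project and domain from the execution-scoped variables. A minimal sketch of the lookup (values are whatever the platform injects):

    import os

    # After this change the runtime reads the execution-scoped variables:
    project = os.getenv("FLYTE_INTERNAL_EXECUTION_PROJECT", "")
    domain = os.getenv("FLYTE_INTERNAL_EXECUTION_DOMAIN", "")
    print(project, domain)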
flyte/_code_bundle/_ignore.py CHANGED
@@ -83,8 +83,15 @@ class StandardIgnore(Ignore):
         self.patterns = patterns if patterns else STANDARD_IGNORE_PATTERNS
 
     def _is_ignored(self, path: pathlib.Path) -> bool:
+        # Convert to relative path for pattern matching
+        try:
+            rel_path = path.relative_to(self.root)
+        except ValueError:
+            # If path is not under root, don't ignore it
+            return False
+
         for pattern in self.patterns:
-            if fnmatch(str(path), pattern):
+            if fnmatch(str(rel_path), pattern):
                 return True
         return False
 
@@ -105,9 +112,10 @@ class IgnoreGroup(Ignore):
 
     def list_ignored(self) -> List[str]:
         ignored = []
-        for dir, _, files in self.root.walk():
+        for dir, _, files in os.walk(self.root):
+            dir_path = Path(dir)
             for file in files:
-                abs_path = dir / file
+                abs_path = dir_path / file
                 if self.is_ignored(abs_path):
                     ignored.append(str(abs_path.relative_to(self.root)))
         return ignored
flyte/_code_bundle/_packaging.py CHANGED
@@ -14,10 +14,9 @@ import typing
 from typing import List, Optional, Tuple, Union
 
 import click
-from rich import print as rich_print
 from rich.tree import Tree
 
-from flyte._logging import logger
+from flyte._logging import _get_console, logger
 
 from ._ignore import Ignore, IgnoreGroup
 from ._utils import CopyFiles, _filehash_update, _pathhash_update, ls_files, tar_strip_file_attributes
@@ -27,10 +26,10 @@ FAST_FILEENDING = ".tar.gz"
 
 
 def print_ls_tree(source: os.PathLike, ls: typing.List[str]):
-    click.secho("Files to be copied for fast registration...", fg="bright_blue")
+    logger.info("Files to be copied for fast registration...")
 
     tree_root = Tree(
-        f":open_file_folder: [link file://{source}]{source} (detected source root)",
+        f"File structure:\n:open_file_folder: {source}",
         guide_style="bold bright_blue",
     )
     trees = {pathlib.Path(source): tree_root}
@@ -49,7 +48,12 @@ def print_ls_tree(source: os.PathLike, ls: typing.List[str]):
         else:
             current = trees[current_path]
         trees[fpp.parent].add(f"{fpp.name}", guide_style="bold bright_blue")
-    rich_print(tree_root)
+
+    console = _get_console()
+    with console.capture() as capture:
+        console.print(tree_root, overflow="ignore", no_wrap=True, crop=False)
+    logger.info(f"Root directory: [link=file://{source}]{source}[/link]")
+    logger.info(capture.get(), extra={"console": console})
 
 
 def _compress_tarball(source: pathlib.Path, output: pathlib.Path) -> None:
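The print_ls_tree change renders the tree into a string with Rich's capture API so it can flow through the logger instead of going straight to stdout. A self-contained sketch of that pattern:

    from rich.console import Console
    from rich.tree import Tree

    console = Console()
    tree = Tree(":open_file_folder: src")
    tree.add("main.py")

    # Render to a string instead of writing to the terminal:
    with console.capture() as capture:
        console.print(tree, overflow="ignore", no_wrap=True, crop=False)
    rendered = capture.get()  # pass this to logger.info(...)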
flyte/_code_bundle/_utils.py CHANGED
@@ -156,7 +156,7 @@ def list_all_files(source_path: pathlib.Path, deref_symlinks, ignore_group: Opti
 
     # This is needed to prevent infinite recursion when walking with followlinks
     visited_inodes = set()
-    for root, dirnames, files in source_path.walk(top_down=True, follow_symlinks=deref_symlinks):
+    for root, dirnames, files in os.walk(source_path, topdown=True, followlinks=deref_symlinks):
         dirnames[:] = [d for d in dirnames if d not in EXCLUDE_DIRS]
         if deref_symlinks:
             inode = os.stat(root).st_ino
@@ -167,7 +167,7 @@ def list_all_files(source_path: pathlib.Path, deref_symlinks, ignore_group: Opti
         ff = []
         files.sort()
         for fname in files:
-            abspath = (root / fname).absolute()
+            abspath = (pathlib.Path(root) / fname).absolute()
             # Only consider files that exist (e.g. disregard symlinks that point to non-existent files)
             if not os.path.exists(abspath):
                 logger.info(f"Skipping non-existent file {abspath}")
flyte/_initialize.py CHANGED
@@ -110,6 +110,12 @@ async def _initialize_client(
     )
 
 
+def _initialize_logger(log_level: int | None = None):
+    initialize_logger(enable_rich=True)
+    if log_level:
+        initialize_logger(log_level=log_level, enable_rich=True)
+
+
 @syncify
 async def init(
     org: str | None = None,
@@ -172,14 +178,9 @@ async def init(
 
     :return: None
     """
-    from flyte._tools import ipython_check
     from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
 
-    interactive_mode = ipython_check()
-
-    initialize_logger(enable_rich=interactive_mode)
-    if log_level:
-        initialize_logger(log_level=log_level, enable_rich=interactive_mode)
+    _initialize_logger(log_level=log_level)
 
     global _init_config  # noqa: PLW0603
 
@@ -231,6 +232,7 @@ async def init_from_config(
     path_or_config: str | Path | Config | None = None,
     root_dir: Path | None = None,
     log_level: int | None = None,
+    storage: Storage | None = None,
 ) -> None:
     """
     Initialize the Flyte system using a configuration file or Config object. This method should be called before any
@@ -245,6 +247,8 @@ async def init_from_config(
         default is set using the default initialization policies
     :return: None
     """
+    from rich.highlighter import ReprHighlighter
+
     import flyte.config as config
 
     cfg: config.Config
@@ -266,7 +270,9 @@ async def init_from_config(
     else:
         cfg = path_or_config
 
-    logger.debug(f"Flyte config initialized as {cfg}")
+    _initialize_logger(log_level=log_level)
+
+    logger.info(f"Flyte config initialized as {cfg}", extra={"highlighter": ReprHighlighter()})
     await init.aio(
         org=cfg.task.org,
         project=cfg.task.project,
@@ -283,6 +289,7 @@ async def init_from_config(
         root_dir=root_dir,
         log_level=log_level,
         image_builder=cfg.image.builder,
+        storage=storage,
     )
 
 
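The init_from_config change logs the parsed config with a per-record Rich highlighter. Rich's RichHandler picks up a highlighter attached via extra, as in this sketch (logger name and message are illustrative):

    import logging

    from rich.highlighter import ReprHighlighter
    from rich.logging import RichHandler

    logging.basicConfig(level=logging.INFO, handlers=[RichHandler()])
    log = logging.getLogger("flyte")

    # RichHandler reads record.highlighter, so `extra` switches highlighting per message:
    log.info("Config(endpoint='dns:///example.org')", extra={"highlighter": ReprHighlighter()})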
flyte/_interface.py CHANGED
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
 import inspect
-from typing import Dict, Generator, Tuple, Type, TypeVar, Union, cast, get_args, get_type_hints
+import typing
+from enum import Enum
+from typing import Dict, Generator, Literal, Tuple, Type, TypeVar, Union, cast, get_args, get_origin, get_type_hints
+
+from flyte._logging import logger
 
 
 def default_output_name(index: int = 0) -> str:
@@ -69,7 +73,15 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
         if len(return_annotation.__args__) == 1:  # type: ignore
             raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
         ra = get_args(return_annotation)
-        return dict(zip(list(output_name_generator(len(ra))), ra))
+        annotations = {}
+        for i, r in enumerate(ra):
+            if r is Ellipsis:
+                raise TypeError("Variable length tuples are not supported as return types.")
+            if get_origin(r) is Literal:
+                annotations[default_output_name(i)] = literal_to_enum(cast(Type, r))
+            else:
+                annotations[default_output_name(i)] = r
+        return annotations
 
     elif isinstance(return_annotation, tuple):
         if len(return_annotation) == 1:
@@ -79,4 +91,25 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
     else:
         # Handle all other single return types
         # Task returns unnamed native tuple
+        if get_origin(return_annotation) is Literal:
+            return {default_output_name(): literal_to_enum(cast(Type, return_annotation))}
         return {default_output_name(): cast(Type, return_annotation)}
+
+
+def literal_to_enum(literal_type: Type) -> Type[Enum | typing.Any]:
+    """Convert a Literal[...] into Union[str, Enum]."""
+
+    if get_origin(literal_type) is not Literal:
+        raise TypeError(f"{literal_type} is not a Literal")
+
+    values = get_args(literal_type)
+    if not all(isinstance(v, str) for v in values):
+        logger.warning(f"Literal type {literal_type} contains non-string values, using Any instead of Enum")
+        return typing.Any
+    # Deduplicate & keep order
+    enum_dict = {str(v).upper(): v for v in values}
+
+    # Dynamically create an Enum
+    literal_enum = Enum("LiteralEnum", enum_dict)  # type: ignore
+
+    return literal_enum  # type: ignore
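Effect of the new Literal handling, sketched with the same dynamic-Enum construction literal_to_enum performs (the type alias and member names here are illustrative):

    from enum import Enum
    from typing import Literal, get_args

    Status = Literal["pending", "done"]

    # literal_to_enum builds an Enum keyed by the upper-cased string values:
    members = {str(v).upper(): v for v in get_args(Status)}
    StatusEnum = Enum("LiteralEnum", members)
    print(StatusEnum.PENDING.value)  # "pending"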
flyte/_internal/controllers/remote/__init__.py CHANGED
@@ -54,7 +54,5 @@ def create_remote_controller(
 
     controller = RemoteController(
         client_coro=client_coro,
-        workers=10,
-        max_system_retries=5,
     )
     return controller
flyte/_internal/controllers/remote/_controller.py CHANGED
@@ -117,9 +117,8 @@ class RemoteController(Controller):
     def __init__(
         self,
         client_coro: Awaitable[ClientSet],
-        workers: int,
-        max_system_retries: int,
-        default_parent_concurrency: int = 100,
+        workers: int = 20,
+        max_system_retries: int = 10,
     ):
         """ """
         super().__init__(
@@ -127,6 +126,7 @@ class RemoteController(Controller):
             workers=workers,
             max_system_retries=max_system_retries,
         )
+        default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
         self._default_parent_concurrency = default_parent_concurrency
         self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
             lambda: asyncio.Semaphore(default_parent_concurrency)
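The hard-coded constructor arguments move to environment-driven defaults; the per-parent semaphore itself is unchanged. A sketch of the resulting setup:

    import asyncio
    import os
    from collections import defaultdict
    from typing import DefaultDict

    # Concurrency per parent action now comes from _F_P_CNC (default 100):
    default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))

    # One semaphore per parent action name, created lazily on first access:
    semaphores: DefaultDict[str, asyncio.Semaphore] = defaultdict(
        lambda: asyncio.Semaphore(default_parent_concurrency)
    )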
flyte/_internal/controllers/remote/_core.py CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 import asyncio
+import os
 import sys
 import threading
 from asyncio import Event
 from typing import Awaitable, Coroutine, Optional
 
 import grpc.aio
+from aiolimiter import AsyncLimiter
 from google.protobuf.wrappers_pb2 import StringValue
 
 import flyte.errors
@@ -32,10 +34,10 @@ class Controller:
     def __init__(
         self,
         client_coro: Awaitable[ClientSet],
-        workers: int = 2,
-        max_system_retries: int = 5,
+        workers: int = 20,
+        max_system_retries: int = 10,
         resource_log_interval_sec: float = 10.0,
-        min_backoff_on_err_sec: float = 0.1,
+        min_backoff_on_err_sec: float = 0.5,
         thread_wait_timeout_sec: float = 5.0,
         enqueue_timeout_sec: float = 5.0,
     ):
@@ -53,14 +55,17 @@
         self._running = False
         self._resource_log_task = None
         self._workers = workers
-        self._max_retries = max_system_retries
+        self._max_retries = int(os.getenv("_F_MAX_RETRIES", max_system_retries))
         self._resource_log_interval = resource_log_interval_sec
         self._min_backoff_on_err = min_backoff_on_err_sec
+        self._max_backoff_on_err = float(os.getenv("_F_MAX_BFF_ON_ERR", "10.0"))
         self._thread_wait_timeout = thread_wait_timeout_sec
         self._client_coro = client_coro
         self._failure_event: Event | None = None
         self._enqueue_timeout = enqueue_timeout_sec
         self._informer_start_wait_timeout = thread_wait_timeout_sec
+        max_qps = int(os.getenv("_F_MAX_QPS", "100"))
+        self._rate_limiter = AsyncLimiter(max_qps, 1.0)
 
         # Thread management
         self._thread = None
@@ -194,15 +199,16 @@
         # We will wait for this to signal that the thread is ready
         # Signal the main thread that we're ready
         logger.debug("Background thread initialization complete")
-        self._thread_ready.set()
         if sys.version_info >= (3, 11):
             async with asyncio.TaskGroup() as tg:
                 for i in range(self._workers):
-                    tg.create_task(self._bg_run())
+                    tg.create_task(self._bg_run(f"worker-{i}"))
+                self._thread_ready.set()
         else:
             tasks = []
             for i in range(self._workers):
-                tasks.append(asyncio.create_task(self._bg_run()))
+                tasks.append(asyncio.create_task(self._bg_run(f"worker-{i}")))
+            self._thread_ready.set()
             await asyncio.gather(*tasks)
 
     def _bg_thread_target(self):
@@ -221,6 +227,7 @@
             except Exception as e:
                 logger.error(f"Controller thread encountered an exception: {e}")
                 self._set_exception(e)
+                self._failure_event.set()
             finally:
                 if self._loop and self._loop.is_running():
                     self._loop.close()
@@ -292,21 +299,22 @@
         started = action.is_started()
         action.mark_cancelled()
         if started:
-            logger.info(f"Cancelling action: {action.name}")
-            try:
-                # TODO add support when the queue service supports aborting actions
-                # await self._queue_service.AbortQueuedAction(
-                #     queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
-                #     wait_for_ready=True,
-                # )
-                logger.info(f"Successfully cancelled action: {action.name}")
-            except grpc.aio.AioRpcError as e:
-                if e.code() in [
-                    grpc.StatusCode.NOT_FOUND,
-                    grpc.StatusCode.FAILED_PRECONDITION,
-                ]:
-                    logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
-                    return
+            async with self._rate_limiter:
+                logger.info(f"Cancelling action: {action.name}")
+                try:
+                    # TODO add support when the queue service supports aborting actions
+                    # await self._queue_service.AbortQueuedAction(
+                    #     queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
+                    #     wait_for_ready=True,
+                    # )
+                    logger.info(f"Successfully cancelled action: {action.name}")
+                except grpc.aio.AioRpcError as e:
+                    if e.code() in [
+                        grpc.StatusCode.NOT_FOUND,
+                        grpc.StatusCode.FAILED_PRECONDITION,
+                    ]:
+                        logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
+                        return
         else:
             # If the action is not started, we have to ensure it does not get launched
             logger.info(f"Action {action.name} is not started, no need to cancel.")
@@ -320,56 +328,69 @@
         Attempt to launch an action.
         """
         if not action.is_started():
-            task: queue_service_pb2.TaskAction | None = None
-            trace: queue_service_pb2.TraceAction | None = None
-            if action.type == "task":
-                if action.task is None:
-                    raise flyte.errors.RuntimeSystemError(
-                        "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
-                    )
-                cache_key = None
-                logger.info(f"Action {action.name} has cache version {action.cache_key}")
-                if action.cache_key:
-                    cache_key = StringValue(value=action.cache_key)
-
-                task = queue_service_pb2.TaskAction(
-                    id=task_definition_pb2.TaskIdentifier(
-                        version=action.task.task_template.id.version,
-                        org=action.task.task_template.id.org,
-                        project=action.task.task_template.id.project,
-                        domain=action.task.task_template.id.domain,
-                        name=action.task.task_template.id.name,
-                    ),
-                    spec=action.task,
-                    cache_key=cache_key,
-                )
-            elif action.type == "trace":
-                trace = action.trace
-
-            logger.debug(f"Attempting to launch action: {action.name}")
-            try:
-                await self._queue_service.EnqueueAction(
-                    queue_service_pb2.EnqueueActionRequest(
-                        action_id=action.action_id,
-                        parent_action_name=action.parent_action_name,
-                        task=task,
-                        trace=trace,
-                        input_uri=action.inputs_uri,
-                        run_output_base=action.run_output_base,
-                        group=action.group.name if action.group else None,
-                        # Subject is not used in the current implementation
-                    ),
-                    wait_for_ready=True,
-                    timeout=self._enqueue_timeout,
-                )
-                logger.info(f"Successfully launched action: {action.name}")
-            except grpc.aio.AioRpcError as e:
-                if e.code() == grpc.StatusCode.ALREADY_EXISTS:
-                    logger.info(f"Action {action.name} already exists, continuing to monitor.")
-                    return
-                logger.exception(f"Failed to launch action: {action.name} backing off...")
-                logger.debug(f"Action details: {action}")
-                raise e
+            async with self._rate_limiter:
+                task: queue_service_pb2.TaskAction | None = None
+                trace: queue_service_pb2.TraceAction | None = None
+                if action.type == "task":
+                    if action.task is None:
+                        raise flyte.errors.RuntimeSystemError(
+                            "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+                        )
+                    cache_key = None
+                    logger.info(f"Action {action.name} has cache version {action.cache_key}")
+                    if action.cache_key:
+                        cache_key = StringValue(value=action.cache_key)
+
+                    task = queue_service_pb2.TaskAction(
+                        id=task_definition_pb2.TaskIdentifier(
+                            version=action.task.task_template.id.version,
+                            org=action.task.task_template.id.org,
+                            project=action.task.task_template.id.project,
+                            domain=action.task.task_template.id.domain,
+                            name=action.task.task_template.id.name,
+                        ),
+                        spec=action.task,
+                        cache_key=cache_key,
+                    )
+                elif action.type == "trace":
+                    trace = action.trace
+
+                logger.debug(f"Attempting to launch action: {action.name}")
+                try:
+                    await self._queue_service.EnqueueAction(
+                        queue_service_pb2.EnqueueActionRequest(
+                            action_id=action.action_id,
+                            parent_action_name=action.parent_action_name,
+                            task=task,
+                            trace=trace,
+                            input_uri=action.inputs_uri,
+                            run_output_base=action.run_output_base,
+                            group=action.group.name if action.group else None,
+                            # Subject is not used in the current implementation
+                        ),
+                        wait_for_ready=True,
+                        timeout=self._enqueue_timeout,
+                    )
+                    logger.info(f"Successfully launched action: {action.name}")
+                except grpc.aio.AioRpcError as e:
+                    if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                        logger.info(f"Action {action.name} already exists, continuing to monitor.")
+                        return
+                    if e.code() in [
+                        grpc.StatusCode.FAILED_PRECONDITION,
+                        grpc.StatusCode.INVALID_ARGUMENT,
+                        grpc.StatusCode.NOT_FOUND,
+                    ]:
+                        raise flyte.errors.RuntimeSystemError(
+                            e.code().name, f"Precondition failed: {e.details()}"
+                        ) from e
+                    # For all other errors, we will retry with backoff
+                    logger.exception(
+                        f"Failed to launch action: {action.name}, Code: {e.code()}, "
+                        f"Details {e.details()} backing off..."
+                    )
+                    logger.debug(f"Action details: {action}")
+                    raise flyte.errors.SlowDownError(f"Failed to launch action: {e.details()}") from e
 
     @log
     async def _bg_process(self, action: Action):
@@ -397,35 +418,42 @@
         await asyncio.sleep(self._resource_log_interval)
 
     @log
-    async def _bg_run(self):
+    async def _bg_run(self, worker_id: str):
         """Run loop with resource status logging"""
+        logger.info(f"Worker {worker_id} started")
         while self._running:
             logger.debug(f"{threading.current_thread().name} Waiting for resource")
             action = await self._shared_queue.get()
             logger.debug(f"{threading.current_thread().name} Got resource {action.name}")
             try:
                 await self._bg_process(action)
-            except Exception as e:
-                logger.error(f"Error in controller loop: {e}")
-                # TODO we need a better way of handling backoffs currently the entire worker coroutine backs off
-                await asyncio.sleep(self._min_backoff_on_err)
-                action.increment_retries()
+            except flyte.errors.SlowDownError as e:
+                action.retries += 1
                 if action.retries > self._max_retries:
-                    err = flyte.errors.RuntimeSystemError(
-                        code=type(e).__name__,
-                        message=f"Controller failed, system retries {action.retries}"
-                        f" crossed threshold {self._max_retries}",
-                    )
-                    err.__cause__ = e
-                    action.set_client_error(err)
-                    informer = await self._informers.get(
-                        run_name=action.run_name,
-                        parent_action_name=action.parent_action_name,
-                    )
-                    if informer:
-                        await informer.fire_completion_event(action.name)
-                else:
-                    await self._shared_queue.put(action)
+                    raise
+                backoff = min(self._min_backoff_on_err * (2 ** (action.retries - 1)), self._max_backoff_on_err)
+                logger.warning(
+                    f"[{worker_id}] Backing off for {backoff} [retry {action.retries}/{self._max_retries}] "
+                    f"on action {action.name} due to error: {e}"
+                )
+                await asyncio.sleep(backoff)
+                logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
+                await self._shared_queue.put(action)
+            except Exception as e:
+                logger.error(f"[{worker_id}] Error in controller loop: {e}")
+                err = flyte.errors.RuntimeSystemError(
+                    code=type(e).__name__,
+                    message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
+                    worker=worker_id,
+                )
+                err.__cause__ = e
+                action.set_client_error(err)
+                informer = await self._informers.get(
+                    run_name=action.run_name,
+                    parent_action_name=action.parent_action_name,
+                )
+                if informer:
+                    await informer.fire_completion_event(action.name)
             finally:
                 self._shared_queue.task_done()
 
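The controller now funnels enqueue and cancel RPCs through aiolimiter's AsyncLimiter, a leaky-bucket rate limiter. A minimal usage sketch (function name is illustrative):

    import asyncio

    from aiolimiter import AsyncLimiter

    limiter = AsyncLimiter(100, 1.0)  # at most 100 acquisitions per 1-second window

    async def enqueue_one(i: int) -> None:
        async with limiter:  # suspends here if the bucket is exhausted
            print(f"rpc {i}")

    async def main() -> None:
        await asyncio.gather(*(enqueue_one(i) for i in range(5)))

    asyncio.run(main())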
flyte/_internal/controllers/remote/_informer.py CHANGED
@@ -132,8 +132,10 @@ class Informer:
         parent_action_name: str,
         shared_queue: Queue,
         client: Optional[StateService] = None,
-        watch_backoff_interval_sec: float = 1.0,
+        min_watch_backoff: float = 1.0,
+        max_watch_backoff: float = 30.0,
         watch_conn_timeout_sec: float = 5.0,
+        max_watch_retries: int = 10,
     ):
         self.name = self.mkname(run_name=run_id.name, parent_action_name=parent_action_name)
         self.parent_action_name = parent_action_name
@@ -144,8 +146,10 @@ class Informer:
         self._running = False
         self._watch_task: asyncio.Task | None = None
         self._ready = asyncio.Event()
-        self._watch_backoff_interval_sec = watch_backoff_interval_sec
+        self._min_watch_backoff = min_watch_backoff
+        self._max_watch_backoff = max_watch_backoff
         self._watch_conn_timeout_sec = watch_conn_timeout_sec
+        self._max_watch_retries = max_watch_retries
 
     @classmethod
     def mkname(cls, *, run_name: str, parent_action_name: str) -> str:
@@ -211,13 +215,16 @@ class Informer:
         """
         # sentinel = False
         retries = 0
-        max_retries = 5
         last_exc = None
         while self._running:
-            if retries >= max_retries:
-                logger.error(f"Informer watch failure retries crossed threshold {retries}/{max_retries}, exiting!")
+            if retries >= self._max_watch_retries:
+                logger.error(
+                    f"Informer watch failure retries crossed threshold {retries}/{self._max_watch_retries}, exiting!"
+                )
                 raise last_exc
             try:
+                if retries >= 1:
+                    logger.warning(f"Informer watch retrying, attempt {retries}/{self._max_watch_retries}")
                 watcher = self._client.Watch(
                     state_service_pb2.WatchRequest(
                         parent_action_id=identifier_pb2.ActionIdentifier(
@@ -252,7 +259,9 @@ class Informer:
                 logger.exception(f"Watch error: {self.name}", exc_info=e)
                 last_exc = e
                 retries += 1
-                await asyncio.sleep(self._watch_backoff_interval_sec)
+                backoff = min(self._min_watch_backoff * (2**retries), self._max_watch_backoff)
+                logger.warning(f"Watch for {self.name} failed, retrying in {backoff} seconds...")
+                await asyncio.sleep(backoff)
 
     @log
     async def start(self, timeout: Optional[float] = None) -> asyncio.Task:
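Both the controller and the informer replace their fixed sleeps with a capped exponential backoff. The schedule, computed with the informer's formula and defaults (min 1.0s, cap 30.0s):

    min_backoff, max_backoff = 1.0, 30.0
    for retries in range(1, 7):
        backoff = min(min_backoff * (2**retries), max_backoff)
        print(retries, backoff)  # 2.0, 4.0, 8.0, 16.0, 30.0, 30.0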