flyte 2.0.0b18__py3-none-any.whl → 2.0.0b20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte has been flagged as potentially problematic.

flyte/_bin/runtime.py CHANGED
@@ -101,7 +101,6 @@ def main(
  from flyte._logging import logger
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath

- logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
  logger.info("Registering faulthandler for SIGUSR1")
  faulthandler.register(signal.SIGUSR1)

@@ -117,6 +116,8 @@ def main(
  if name.startswith("{{"):
  name = os.getenv("ACTION_NAME", "")

+ logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
+
  if debug and name == "a0":
  from flyte._debug.vscode import _start_vscode_server

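Note: the startup warning now fires after the templated action name is resolved, so it logs the real action name rather than an unresolved "{{...}}" placeholder. A minimal sketch of the resolution order (values here are hypothetical, not the actual entrypoint wiring):

import os

name = "{{.actionName}}"   # placeholder as injected before substitution
if name.startswith("{{"):  # still templated, fall back to the env var
    name = os.getenv("ACTION_NAME", "")
print(f"Flyte runtime started for action {name}")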
flyte/_initialize.py CHANGED
@@ -228,7 +228,7 @@ async def init(

  @syncify
  async def init_from_config(
- path_or_config: str | Config | None = None,
+ path_or_config: str | Path | Config | None = None,
  root_dir: Path | None = None,
  log_level: int | None = None,
  ) -> None:
@@ -251,11 +251,11 @@ async def init_from_config(
  if path_or_config is None:
  # If no path is provided, use the default config file
  cfg = config.auto()
- elif isinstance(path_or_config, str):
+ elif isinstance(path_or_config, (str, Path)):
  if root_dir:
- cfg_path = str(root_dir / path_or_config)
+ cfg_path = root_dir.expanduser() / path_or_config
  else:
- cfg_path = path_or_config
+ cfg_path = Path(path_or_config).expanduser()
  if not Path(cfg_path).exists():
  raise InitializationError(
  "ConfigFileNotFoundError",
@@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tu
  from flyte._task import TaskTemplate
  from flyte.models import ActionID, NativeInterface

+ if TYPE_CHECKING:
+ from flyte.remote._task import TaskDetails
+
  from ._trace import TraceInfo

  __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "get_controller"]

- from ..._protos.workflow import task_definition_pb2
-
  if TYPE_CHECKING:
  import concurrent.futures

@@ -41,9 +42,7 @@ class Controller(Protocol):
  """
  ...

- async def submit_task_ref(
- self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
- ) -> Any:
+ async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
  """
  Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
  the current coroutine until the result is available.
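The protocol now type-hints against TaskDetails imported only under TYPE_CHECKING, with the annotation quoted so flyte.remote._task is never imported at runtime; the runtime dependency on task_definition_pb2 is dropped entirely. A minimal sketch of the pattern:

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen by type checkers only; never imported at runtime, which
    # avoids circular imports between the controller and remote modules.
    from flyte.remote._task import TaskDetails

class Controller:
    async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
        ...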
@@ -11,10 +11,10 @@ from flyte._internal.controllers import TraceInfo
  from flyte._internal.runtime import convert
  from flyte._internal.runtime.entrypoints import direct_dispatch
  from flyte._logging import log, logger
- from flyte._protos.workflow import task_definition_pb2
  from flyte._task import TaskTemplate
  from flyte._utils.helpers import _selector_policy
  from flyte.models import ActionID, NativeInterface
+ from flyte.remote._task import TaskDetails

  R = TypeVar("R")

@@ -192,7 +192,7 @@ class LocalController:
  assert info.start_time
  assert info.end_time

- async def submit_task_ref(
- self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
- ) -> Any:
- raise flyte.errors.ReferenceTaskError("Reference tasks cannot be executed locally, only remotely.")
+ async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *args, **kwargs) -> Any:
+ raise flyte.errors.ReferenceTaskError(
+ f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
+ )
@@ -54,7 +54,5 @@ def create_remote_controller(

  controller = RemoteController(
  client_coro=client_coro,
- workers=10,
- max_system_retries=5,
  )
  return controller
@@ -12,7 +12,6 @@ from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
  import flyte
  import flyte.errors
  import flyte.storage as storage
- import flyte.types as types
  from flyte._code_bundle import build_pkl_bundle
  from flyte._context import internal_ctx
  from flyte._internal.controllers import TraceInfo
@@ -24,10 +23,11 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
  from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
  from flyte._logging import logger
  from flyte._protos.common import identifier_pb2
- from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
+ from flyte._protos.workflow import run_definition_pb2
  from flyte._task import TaskTemplate
  from flyte._utils.helpers import _selector_policy
  from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
+ from flyte.remote._task import TaskDetails

  R = TypeVar("R")

@@ -117,9 +117,8 @@ class RemoteController(Controller):
  def __init__(
  self,
  client_coro: Awaitable[ClientSet],
- workers: int,
- max_system_retries: int,
- default_parent_concurrency: int = 100,
+ workers: int = 20,
+ max_system_retries: int = 10,
  ):
  """ """
  super().__init__(
@@ -127,6 +126,7 @@ class RemoteController(Controller):
  workers=workers,
  max_system_retries=max_system_retries,
  )
+ default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
  self._default_parent_concurrency = default_parent_concurrency
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
  lambda: asyncio.Semaphore(default_parent_concurrency)
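Parent-action concurrency is no longer a constructor argument; it is read from the _F_P_CNC environment variable (default 100) and enforced with one asyncio.Semaphore per parent action. A self-contained sketch of the same pattern (names here are illustrative):

import asyncio
import os
from collections import defaultdict
from typing import DefaultDict

default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
# One semaphore per parent action, created lazily on first access.
semaphores: DefaultDict[str, asyncio.Semaphore] = defaultdict(
    lambda: asyncio.Semaphore(default_parent_concurrency)
)

async def submit_under(parent_action: str) -> None:
    # At most default_parent_concurrency submissions per parent run at once.
    async with semaphores[parent_action]:
        ...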
@@ -482,19 +482,17 @@ class RemoteController(Controller):
  # If the action is cancelled, we need to cancel the action on the server as well
  raise

- async def _submit_task_ref(
- self, invoke_seq_num: int, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
- ) -> Any:
+ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, **kwargs) -> Any:
  ctx = internal_ctx()
  tctx = ctx.data.task_context
  if tctx is None:
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
  current_action_id = tctx.action
- task_name = _task.spec.task_template.id.name
+ task_name = _task.name
+
+ native_interface = _task.interface
+ pb_interface = _task.pb2.spec.task_template.interface

- native_interface = types.guess_interface(
- _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
- )
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
  inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
@@ -503,19 +501,19 @@ class RemoteController(Controller):

  serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
  inputs_uri = io.inputs_path(sub_action_output_path)
- await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_inline_io_bytes)
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes)
  # cache key - task name, task signature, inputs, cache version
  cache_key = None
- md = _task.spec.task_template.metadata
+ md = _task.pb2.spec.task_template.metadata
  ignored_input_vars = []
  if len(md.cache_ignore_input_vars) > 0:
  ignored_input_vars = list(md.cache_ignore_input_vars)
- if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable:
- discovery_version = _task.spec.task_template.metadata.discovery_version
+ if md and md.discoverable:
+ discovery_version = md.discovery_version
  cache_key = convert.generate_cache_key_hash(
  task_name,
  inputs_hash,
- _task.spec.task_template.interface,
+ pb_interface,
  discovery_version,
  ignored_input_vars,
  inputs.proto_inputs,
@@ -537,7 +535,7 @@ class RemoteController(Controller):
  ),
  parent_action_name=current_action_id.name,
  group_data=tctx.group_data,
- task_spec=_task.spec,
+ task_spec=_task.pb2.spec,
  inputs_uri=inputs_uri,
  run_output_base=tctx.run_base_dir,
  cache_key=cache_key,
@@ -566,12 +564,10 @@ class RemoteController(Controller):
  "RuntimeError",
  f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
  )
- return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, max_inline_io_bytes)
+ return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes)
  return None

- async def submit_task_ref(
- self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
- ) -> Any:
+ async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
  ctx = internal_ctx()
  tctx = ctx.data.task_context
  if tctx is None:
@@ -579,4 +575,4 @@ class RemoteController(Controller):
  current_action_id = tctx.action
  task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
  async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
- return await self._submit_task_ref(task_call_seq, _task, max_inline_io_bytes, *args, **kwargs)
+ return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)
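Across these hunks, submit_task_ref loses its max_inline_io_bytes parameter: the limit, the native interface, and the raw proto now travel with the TaskDetails wrapper (_task.max_inline_io_bytes, _task.interface, _task.pb2.spec). A hedged before/after sketch of a call site, assuming controller and task_details are already in hand:

# b18: the caller threaded the inline-IO limit through explicitly.
# result = await controller.submit_task_ref(task_details, max_inline_io_bytes, x=1)

# b20: the limit is read from the TaskDetails wrapper itself.
result = await controller.submit_task_ref(task_details, x=1)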
@@ -1,12 +1,14 @@
  from __future__ import annotations

  import asyncio
+ import os
  import sys
  import threading
  from asyncio import Event
  from typing import Awaitable, Coroutine, Optional

  import grpc.aio
+ from aiolimiter import AsyncLimiter
  from google.protobuf.wrappers_pb2 import StringValue

  import flyte.errors
@@ -32,10 +34,10 @@ class Controller:
  def __init__(
  self,
  client_coro: Awaitable[ClientSet],
- workers: int = 2,
- max_system_retries: int = 5,
+ workers: int = 20,
+ max_system_retries: int = 10,
  resource_log_interval_sec: float = 10.0,
- min_backoff_on_err_sec: float = 0.1,
+ min_backoff_on_err_sec: float = 0.5,
  thread_wait_timeout_sec: float = 5.0,
  enqueue_timeout_sec: float = 5.0,
  ):
@@ -53,14 +55,17 @@ class Controller:
  self._running = False
  self._resource_log_task = None
  self._workers = workers
- self._max_retries = max_system_retries
+ self._max_retries = int(os.getenv("_F_MAX_RETRIES", max_system_retries))
  self._resource_log_interval = resource_log_interval_sec
  self._min_backoff_on_err = min_backoff_on_err_sec
+ self._max_backoff_on_err = float(os.getenv("_F_MAX_BFF_ON_ERR", "10.0"))
  self._thread_wait_timeout = thread_wait_timeout_sec
  self._client_coro = client_coro
  self._failure_event: Event | None = None
  self._enqueue_timeout = enqueue_timeout_sec
  self._informer_start_wait_timeout = thread_wait_timeout_sec
+ max_qps = int(os.getenv("_F_MAX_QPS", "100"))
+ self._rate_limiter = AsyncLimiter(max_qps, 1.0)

  # Thread management
  self._thread = None
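Outbound launch/cancel RPCs are now throttled with aiolimiter's AsyncLimiter, constructed as AsyncLimiter(max_qps, 1.0), i.e. at most _F_MAX_QPS acquisitions per one-second window (default 100). A runnable sketch of the same throttle in isolation:

import asyncio
import os

from aiolimiter import AsyncLimiter

rate_limiter = AsyncLimiter(int(os.getenv("_F_MAX_QPS", "100")), 1.0)

async def send_rpc(i: int) -> None:
    async with rate_limiter:  # blocks once the per-second budget is spent
        print(f"rpc {i}")

async def main() -> None:
    await asyncio.gather(*(send_rpc(i) for i in range(250)))

asyncio.run(main())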
@@ -194,15 +199,16 @@ class Controller:
  # We will wait for this to signal that the thread is ready
  # Signal the main thread that we're ready
  logger.debug("Background thread initialization complete")
- self._thread_ready.set()
  if sys.version_info >= (3, 11):
  async with asyncio.TaskGroup() as tg:
  for i in range(self._workers):
- tg.create_task(self._bg_run())
+ tg.create_task(self._bg_run(f"worker-{i}"))
+ self._thread_ready.set()
  else:
  tasks = []
  for i in range(self._workers):
- tasks.append(asyncio.create_task(self._bg_run()))
+ tasks.append(asyncio.create_task(self._bg_run(f"worker-{i}")))
+ self._thread_ready.set()
  await asyncio.gather(*tasks)

  def _bg_thread_target(self):
  def _bg_thread_target(self):
@@ -221,6 +227,7 @@ class Controller:
221
227
  except Exception as e:
222
228
  logger.error(f"Controller thread encountered an exception: {e}")
223
229
  self._set_exception(e)
230
+ self._failure_event.set()
224
231
  finally:
225
232
  if self._loop and self._loop.is_running():
226
233
  self._loop.close()
@@ -292,21 +299,22 @@ class Controller:
  started = action.is_started()
  action.mark_cancelled()
  if started:
- logger.info(f"Cancelling action: {action.name}")
- try:
- # TODO add support when the queue service supports aborting actions
- # await self._queue_service.AbortQueuedAction(
- # queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
- # wait_for_ready=True,
- # )
- logger.info(f"Successfully cancelled action: {action.name}")
- except grpc.aio.AioRpcError as e:
- if e.code() in [
- grpc.StatusCode.NOT_FOUND,
- grpc.StatusCode.FAILED_PRECONDITION,
- ]:
- logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
- return
+ async with self._rate_limiter:
+ logger.info(f"Cancelling action: {action.name}")
+ try:
+ # TODO add support when the queue service supports aborting actions
+ # await self._queue_service.AbortQueuedAction(
+ # queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
+ # wait_for_ready=True,
+ # )
+ logger.info(f"Successfully cancelled action: {action.name}")
+ except grpc.aio.AioRpcError as e:
+ if e.code() in [
+ grpc.StatusCode.NOT_FOUND,
+ grpc.StatusCode.FAILED_PRECONDITION,
+ ]:
+ logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
+ return
  else:
  # If the action is not started, we have to ensure it does not get launched
  logger.info(f"Action {action.name} is not started, no need to cancel.")
@@ -320,56 +328,69 @@ class Controller:
  Attempt to launch an action.
  """
  if not action.is_started():
- task: queue_service_pb2.TaskAction | None = None
- trace: queue_service_pb2.TraceAction | None = None
- if action.type == "task":
- if action.task is None:
- raise flyte.errors.RuntimeSystemError(
- "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+ async with self._rate_limiter:
+ task: queue_service_pb2.TaskAction | None = None
+ trace: queue_service_pb2.TraceAction | None = None
+ if action.type == "task":
+ if action.task is None:
+ raise flyte.errors.RuntimeSystemError(
+ "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+ )
+ cache_key = None
+ logger.info(f"Action {action.name} has cache version {action.cache_key}")
+ if action.cache_key:
+ cache_key = StringValue(value=action.cache_key)
+
+ task = queue_service_pb2.TaskAction(
+ id=task_definition_pb2.TaskIdentifier(
+ version=action.task.task_template.id.version,
+ org=action.task.task_template.id.org,
+ project=action.task.task_template.id.project,
+ domain=action.task.task_template.id.domain,
+ name=action.task.task_template.id.name,
+ ),
+ spec=action.task,
+ cache_key=cache_key,
  )
- cache_key = None
- logger.info(f"Action {action.name} has cache version {action.cache_key}")
- if action.cache_key:
- cache_key = StringValue(value=action.cache_key)
-
- task = queue_service_pb2.TaskAction(
- id=task_definition_pb2.TaskIdentifier(
- version=action.task.task_template.id.version,
- org=action.task.task_template.id.org,
- project=action.task.task_template.id.project,
- domain=action.task.task_template.id.domain,
- name=action.task.task_template.id.name,
- ),
- spec=action.task,
- cache_key=cache_key,
- )
- elif action.type == "trace":
- trace = action.trace
-
- logger.debug(f"Attempting to launch action: {action.name}")
- try:
- await self._queue_service.EnqueueAction(
- queue_service_pb2.EnqueueActionRequest(
- action_id=action.action_id,
- parent_action_name=action.parent_action_name,
- task=task,
- trace=trace,
- input_uri=action.inputs_uri,
- run_output_base=action.run_output_base,
- group=action.group.name if action.group else None,
- # Subject is not used in the current implementation
- ),
- wait_for_ready=True,
- timeout=self._enqueue_timeout,
- )
- logger.info(f"Successfully launched action: {action.name}")
- except grpc.aio.AioRpcError as e:
- if e.code() == grpc.StatusCode.ALREADY_EXISTS:
- logger.info(f"Action {action.name} already exists, continuing to monitor.")
- return
- logger.exception(f"Failed to launch action: {action.name} backing off...")
- logger.debug(f"Action details: {action}")
- raise e
+ elif action.type == "trace":
+ trace = action.trace
+
+ logger.debug(f"Attempting to launch action: {action.name}")
+ try:
+ await self._queue_service.EnqueueAction(
+ queue_service_pb2.EnqueueActionRequest(
+ action_id=action.action_id,
+ parent_action_name=action.parent_action_name,
+ task=task,
+ trace=trace,
+ input_uri=action.inputs_uri,
+ run_output_base=action.run_output_base,
+ group=action.group.name if action.group else None,
+ # Subject is not used in the current implementation
+ ),
+ wait_for_ready=True,
+ timeout=self._enqueue_timeout,
+ )
+ logger.info(f"Successfully launched action: {action.name}")
+ except grpc.aio.AioRpcError as e:
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ logger.info(f"Action {action.name} already exists, continuing to monitor.")
+ return
+ if e.code() in [
+ grpc.StatusCode.FAILED_PRECONDITION,
+ grpc.StatusCode.INVALID_ARGUMENT,
+ grpc.StatusCode.NOT_FOUND,
+ ]:
+ raise flyte.errors.RuntimeSystemError(
+ e.code().name, f"Precondition failed: {e.details()}"
+ ) from e
+ # For all other errors, we will retry with backoff
+ logger.exception(
+ f"Failed to launch action: {action.name}, Code: {e.code()}, "
+ f"Details {e.details()} backing off..."
+ )
+ logger.debug(f"Action details: {action}")
+ raise flyte.errors.SlowDownError(f"Failed to launch action: {e.details()}") from e

  @log
  async def _bg_process(self, action: Action):
@@ -397,35 +418,42 @@ class Controller:
  await asyncio.sleep(self._resource_log_interval)

  @log
- async def _bg_run(self):
+ async def _bg_run(self, worker_id: str):
  """Run loop with resource status logging"""
+ logger.info(f"Worker {worker_id} started")
  while self._running:
  logger.debug(f"{threading.current_thread().name} Waiting for resource")
  action = await self._shared_queue.get()
  logger.debug(f"{threading.current_thread().name} Got resource {action.name}")
  try:
  await self._bg_process(action)
- except Exception as e:
- logger.error(f"Error in controller loop: {e}")
- # TODO we need a better way of handling backoffs currently the entire worker coroutine backs off
- await asyncio.sleep(self._min_backoff_on_err)
- action.increment_retries()
+ except flyte.errors.SlowDownError as e:
+ action.retries += 1
  if action.retries > self._max_retries:
- err = flyte.errors.RuntimeSystemError(
- code=type(e).__name__,
- message=f"Controller failed, system retries {action.retries}"
- f" crossed threshold {self._max_retries}",
- )
- err.__cause__ = e
- action.set_client_error(err)
- informer = await self._informers.get(
- run_name=action.run_name,
- parent_action_name=action.parent_action_name,
- )
- if informer:
- await informer.fire_completion_event(action.name)
- else:
- await self._shared_queue.put(action)
+ raise
+ backoff = min(self._min_backoff_on_err * (2 ** (action.retries - 1)), self._max_backoff_on_err)
+ logger.warning(
+ f"[{worker_id}] Backing off for {backoff} [retry {action.retries}/{self._max_retries}] "
+ f"on action {action.name} due to error: {e}"
+ )
+ await asyncio.sleep(backoff)
+ logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
+ await self._shared_queue.put(action)
+ except Exception as e:
+ logger.error(f"[{worker_id}] Error in controller loop: {e}")
+ err = flyte.errors.RuntimeSystemError(
+ code=type(e).__name__,
+ message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
+ worker=worker_id,
+ )
+ err.__cause__ = e
+ action.set_client_error(err)
+ informer = await self._informers.get(
+ run_name=action.run_name,
+ parent_action_name=action.parent_action_name,
+ )
+ if informer:
+ await informer.fire_completion_event(action.name)
  finally:
  self._shared_queue.task_done()

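Worker errors are now split into retryable SlowDownError (requeue with exponential backoff) and everything else (fail the action immediately). With the new defaults (min_backoff_on_err_sec=0.5, cap 10.0s via _F_MAX_BFF_ON_ERR), the schedule works out as follows:

# Worked example of: min(min_backoff * (2 ** (retries - 1)), max_backoff)
min_backoff, max_backoff = 0.5, 10.0
for retries in range(1, 8):
    print(retries, min(min_backoff * (2 ** (retries - 1)), max_backoff))
# retries 1..7 -> 0.5, 1.0, 2.0, 4.0, 8.0, 10.0, 10.0 (capped)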
@@ -132,8 +132,10 @@ class Informer:
  parent_action_name: str,
  shared_queue: Queue,
  client: Optional[StateService] = None,
- watch_backoff_interval_sec: float = 1.0,
+ min_watch_backoff: float = 1.0,
+ max_watch_backoff: float = 30.0,
  watch_conn_timeout_sec: float = 5.0,
+ max_watch_retries: int = 10,
  ):
  self.name = self.mkname(run_name=run_id.name, parent_action_name=parent_action_name)
  self.parent_action_name = parent_action_name
@@ -144,8 +146,10 @@ class Informer:
  self._running = False
  self._watch_task: asyncio.Task | None = None
  self._ready = asyncio.Event()
- self._watch_backoff_interval_sec = watch_backoff_interval_sec
+ self._min_watch_backoff = min_watch_backoff
+ self._max_watch_backoff = max_watch_backoff
  self._watch_conn_timeout_sec = watch_conn_timeout_sec
+ self._max_watch_retries = max_watch_retries

  @classmethod
  def mkname(cls, *, run_name: str, parent_action_name: str) -> str:
@@ -211,13 +215,16 @@ class Informer:
  """
  # sentinel = False
  retries = 0
- max_retries = 5
  last_exc = None
  while self._running:
- if retries >= max_retries:
- logger.error(f"Informer watch failure retries crossed threshold {retries}/{max_retries}, exiting!")
+ if retries >= self._max_watch_retries:
+ logger.error(
+ f"Informer watch failure retries crossed threshold {retries}/{self._max_watch_retries}, exiting!"
+ )
  raise last_exc
  try:
+ if retries >= 1:
+ logger.warning(f"Informer watch retrying, attempt {retries}/{self._max_watch_retries}")
  watcher = self._client.Watch(
  state_service_pb2.WatchRequest(
  parent_action_id=identifier_pb2.ActionIdentifier(
@@ -252,7 +259,9 @@ class Informer:
  logger.exception(f"Watch error: {self.name}", exc_info=e)
  last_exc = e
  retries += 1
- await asyncio.sleep(self._watch_backoff_interval_sec)
+ backoff = min(self._min_watch_backoff * (2**retries), self._max_watch_backoff)
+ logger.warning(f"Watch for {self.name} failed, retrying in {backoff} seconds...")
+ await asyncio.sleep(backoff)

  @log
  async def start(self, timeout: Optional[float] = None) -> asyncio.Task: