flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0

flyte/_internal/controllers/remote/_controller.py
@@ -9,10 +9,12 @@ from collections.abc import Callable
  from pathlib import Path
  from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
 
+ from flyteidl2.common import identifier_pb2
+ from flyteidl2.workflow import run_definition_pb2
+
  import flyte
  import flyte.errors
  import flyte.storage as storage
- import flyte.types as types
  from flyte._code_bundle import build_pkl_bundle
  from flyte._context import internal_ctx
  from flyte._internal.controllers import TraceInfo
@@ -23,11 +25,10 @@ from flyte._internal.runtime import convert, io
  from flyte._internal.runtime.task_serde import translate_task_to_wire
  from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
  from flyte._logging import logger
- from flyte._protos.common import identifier_pb2
- from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
  from flyte._task import TaskTemplate
  from flyte._utils.helpers import _selector_policy
  from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
+ from flyte.remote._task import TaskDetails
 
  R = TypeVar("R")
 
@@ -117,9 +118,8 @@ class RemoteController(Controller):
      def __init__(
          self,
          client_coro: Awaitable[ClientSet],
-         workers: int,
-         max_system_retries: int,
-         default_parent_concurrency: int = 100,
+         workers: int = 20,
+         max_system_retries: int = 10,
      ):
          """ """
          super().__init__(
@@ -127,6 +127,7 @@ class RemoteController(Controller):
              workers=workers,
              max_system_retries=max_system_retries,
          )
+         default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
          self._default_parent_concurrency = default_parent_concurrency
          self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
              lambda: asyncio.Semaphore(default_parent_concurrency)
@@ -167,7 +168,7 @@ class RemoteController(Controller):
          # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
          code_bundle = tctx.code_bundle
 
-         if code_bundle and code_bundle.pkl:
+         if tctx.interactive_mode or (code_bundle and code_bundle.pkl):
              logger.debug(f"Building new pkl bundle for task {_task.name}")
              code_bundle = await build_pkl_bundle(
                  _task,
@@ -238,6 +239,7 @@ class RemoteController(Controller):
              inputs_uri=inputs_uri,
              run_output_base=tctx.run_base_dir,
              cache_key=cache_key,
+             queue=_task.queue,
          )
 
          try:
@@ -375,11 +377,13 @@ class RemoteController(Controller):
 
          func_name = _func.__name__
          invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
+
          inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
          serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
+         inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
 
          sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
-             tctx, func_name, serialized_inputs, invoke_seq_num
+             tctx, func_name, inputs_hash, invoke_seq_num
          )
 
          inputs_uri = io.inputs_path(sub_action_output_path)
@@ -413,8 +417,7 @@ class RemoteController(Controller):
              else:
                  logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
          elif prev_action.realized_outputs_uri is not None:
-             outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
-             o = await io.load_outputs(outputs_file_path, max_bytes=MAX_TRACE_BYTES)
+             o = await io.load_outputs(prev_action.realized_outputs_uri, max_bytes=MAX_TRACE_BYTES)
              outputs = await convert.convert_outputs_to_native(_interface, o)
              return (
                  TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs),
@@ -436,68 +439,64 @@ class RemoteController(Controller):
 
          current_action_id = tctx.action
          sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name)
+         outputs_file_path: str = ""
 
          if info.interface.has_outputs():
-             outputs_file_path: str = ""
-             if info.output:
-                 outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
-                 outputs_file_path = io.outputs_path(sub_run_output_path)
-                 await io.upload_outputs(outputs, sub_run_output_path, max_bytes=MAX_TRACE_BYTES)
-             elif info.error:
+             if info.error:
                  err = convert.convert_from_native_to_error(info.error)
                  await io.upload_error(err.err, sub_run_output_path)
              else:
-                 raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
-
-             typed_interface = transform_native_to_typed_interface(info.interface)
-
-             trace_action = Action.from_trace(
-                 parent_action_name=current_action_id.name,
-                 action_id=identifier_pb2.ActionIdentifier(
-                     name=info.action.name,
-                     run=identifier_pb2.RunIdentifier(
-                         name=current_action_id.run_name,
-                         project=current_action_id.project,
-                         domain=current_action_id.domain,
-                         org=current_action_id.org,
-                     ),
+                 outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
+                 outputs_file_path = io.outputs_path(sub_run_output_path)
+                 await io.upload_outputs(outputs, sub_run_output_path, max_bytes=MAX_TRACE_BYTES)
+
+         typed_interface = transform_native_to_typed_interface(info.interface)
+
+         trace_action = Action.from_trace(
+             parent_action_name=current_action_id.name,
+             action_id=identifier_pb2.ActionIdentifier(
+                 name=info.action.name,
+                 run=identifier_pb2.RunIdentifier(
+                     name=current_action_id.run_name,
+                     project=current_action_id.project,
+                     domain=current_action_id.domain,
+                     org=current_action_id.org,
                  ),
-                 inputs_uri=info.inputs_path,
-                 outputs_uri=outputs_file_path,
-                 friendly_name=info.name,
-                 group_data=tctx.group_data,
-                 run_output_base=tctx.run_base_dir,
-                 start_time=info.start_time,
-                 end_time=info.end_time,
-                 typed_interface=typed_interface if typed_interface else None,
-             )
+             ),
+             inputs_uri=info.inputs_path,
+             outputs_uri=outputs_file_path,
+             friendly_name=info.name,
+             group_data=tctx.group_data,
+             run_output_base=tctx.run_base_dir,
+             start_time=info.start_time,
+             end_time=info.end_time,
+             typed_interface=typed_interface if typed_interface else None,
+         )
+
+         async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
+             try:
+                 logger.info(
+                     f"Submitting Trace action Run:[{trace_action.run_name},"
+                     f" Parent:[{trace_action.parent_action_name}],"
+                     f" Trace fn:[{info.name}], action:[{info.action.name}]"
+                 )
+                 await self.submit_action(trace_action)
+                 logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")
+             except asyncio.CancelledError:
+                 # If the action is cancelled, we need to cancel the action on the server as well
+                 raise
 
-             async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
-                 try:
-                     logger.info(
-                         f"Submitting Trace action Run:[{trace_action.run_name},"
-                         f" Parent:[{trace_action.parent_action_name}],"
-                         f" Trace fn:[{info.name}], action:[{info.action.name}]"
-                     )
-                     await self.submit_action(trace_action)
-                     logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")
-                 except asyncio.CancelledError:
-                     # If the action is cancelled, we need to cancel the action on the server as well
-                     raise
-
-     async def _submit_task_ref(
-         self, invoke_seq_num: int, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-     ) -> Any:
+     async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, **kwargs) -> Any:
          ctx = internal_ctx()
          tctx = ctx.data.task_context
          if tctx is None:
              raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
          current_action_id = tctx.action
-         task_name = _task.spec.task_template.id.name
+         task_name = _task.name
+
+         native_interface = _task.interface
+         pb_interface = _task.pb2.spec.task_template.interface
 
-         native_interface = types.guess_interface(
-             _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
-         )
          inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
          inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
          sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
@@ -506,19 +505,19 @@ class RemoteController(Controller):
 
          serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
          inputs_uri = io.inputs_path(sub_action_output_path)
-         await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_inline_io_bytes)
+         await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes)
          # cache key - task name, task signature, inputs, cache version
          cache_key = None
-         md = _task.spec.task_template.metadata
+         md = _task.pb2.spec.task_template.metadata
          ignored_input_vars = []
          if len(md.cache_ignore_input_vars) > 0:
              ignored_input_vars = list(md.cache_ignore_input_vars)
-         if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable:
-             discovery_version = _task.spec.task_template.metadata.discovery_version
+         if md and md.discoverable:
+             discovery_version = md.discovery_version
              cache_key = convert.generate_cache_key_hash(
                  task_name,
                  inputs_hash,
-                 _task.spec.task_template.interface,
+                 pb_interface,
                  discovery_version,
                  ignored_input_vars,
                  inputs.proto_inputs,
@@ -540,10 +539,11 @@ class RemoteController(Controller):
              ),
              parent_action_name=current_action_id.name,
              group_data=tctx.group_data,
-             task_spec=_task.spec,
+             task_spec=_task.pb2.spec,
              inputs_uri=inputs_uri,
              run_output_base=tctx.run_base_dir,
              cache_key=cache_key,
+             queue=None,
          )
 
          try:
@@ -569,12 +569,10 @@ class RemoteController(Controller):
                      "RuntimeError",
                      f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
                  )
-             return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, max_inline_io_bytes)
+             return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes)
          return None
 
-     async def submit_task_ref(
-         self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
-     ) -> Any:
+     async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
          ctx = internal_ctx()
          tctx = ctx.data.task_context
          if tctx is None:
@@ -582,4 +580,4 @@ class RemoteController(Controller):
          current_action_id = tctx.action
          task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
          async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
-             return await self._submit_task_ref(task_call_seq, _task, max_inline_io_bytes, *args, **kwargs)
+             return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)
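
Worth calling out from the hunks above: the per-parent concurrency cap is no longer a constructor argument; it is read from the `_F_P_CNC` environment variable (default 1000) and enforced with one `asyncio.Semaphore` per parent action. Below is a minimal, hedged sketch of that pattern outside of Flyte's actual classes; the function and variable names here are illustrative, not SDK API.

import asyncio
import os
from collections import defaultdict
from typing import DefaultDict

# Illustrative only: mirrors how the controller caps in-flight submissions per parent action.
_PARENT_CONCURRENCY = int(os.getenv("_F_P_CNC", "1000"))
_semaphores: DefaultDict[str, asyncio.Semaphore] = defaultdict(
    lambda: asyncio.Semaphore(_PARENT_CONCURRENCY)
)

async def submit(parent_action: str, payload: str) -> str:
    # Each parent action gets its own semaphore, so one busy parent
    # cannot starve submissions belonging to other parents.
    async with _semaphores[parent_action]:
        await asyncio.sleep(0.01)  # stand-in for the real enqueue RPC
        return f"{parent_action}:{payload}"

async def main() -> None:
    results = await asyncio.gather(*(submit("a0", f"sub-{i}") for i in range(5)))
    print(results)

if __name__ == "__main__":
    asyncio.run(main())
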
flyte/_internal/controllers/remote/_core.py
@@ -1,21 +1,21 @@
  from __future__ import annotations
 
  import asyncio
+ import os
  import sys
  import threading
  from asyncio import Event
  from typing import Awaitable, Coroutine, Optional
 
  import grpc.aio
+ from aiolimiter import AsyncLimiter
+ from flyteidl2.common import identifier_pb2
+ from flyteidl2.task import task_definition_pb2
+ from flyteidl2.workflow import queue_service_pb2, run_definition_pb2
  from google.protobuf.wrappers_pb2 import StringValue
 
  import flyte.errors
  from flyte._logging import log, logger
- from flyte._protos.common import identifier_pb2
- from flyte._protos.workflow import (
-     queue_service_pb2,
-     task_definition_pb2,
- )
 
  from ._action import Action
  from ._informer import InformerCache
@@ -32,10 +32,10 @@ class Controller:
      def __init__(
          self,
          client_coro: Awaitable[ClientSet],
-         workers: int = 2,
-         max_system_retries: int = 5,
+         workers: int = 20,
+         max_system_retries: int = 10,
          resource_log_interval_sec: float = 10.0,
-         min_backoff_on_err_sec: float = 0.1,
+         min_backoff_on_err_sec: float = 0.5,
          thread_wait_timeout_sec: float = 5.0,
          enqueue_timeout_sec: float = 5.0,
      ):
@@ -53,14 +53,17 @@ class Controller:
          self._running = False
          self._resource_log_task = None
          self._workers = workers
-         self._max_retries = max_system_retries
+         self._max_retries = int(os.getenv("_F_MAX_RETRIES", max_system_retries))
          self._resource_log_interval = resource_log_interval_sec
          self._min_backoff_on_err = min_backoff_on_err_sec
+         self._max_backoff_on_err = float(os.getenv("_F_MAX_BFF_ON_ERR", "10.0"))
          self._thread_wait_timeout = thread_wait_timeout_sec
          self._client_coro = client_coro
          self._failure_event: Event | None = None
          self._enqueue_timeout = enqueue_timeout_sec
          self._informer_start_wait_timeout = thread_wait_timeout_sec
+         max_qps = int(os.getenv("_F_MAX_QPS", "100"))
+         self._rate_limiter = AsyncLimiter(max_qps, 1.0)
 
          # Thread management
          self._thread = None
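
The new `_rate_limiter` above comes from the `aiolimiter` package: `AsyncLimiter(max_qps, 1.0)` allows at most `max_qps` acquisitions per one-second window, and later hunks wrap the queue-service RPCs in `async with self._rate_limiter:`. A standalone sketch of the same throttling pattern follows; it assumes `aiolimiter` is installed, and the `enqueue` coroutine is illustrative rather than the SDK's.

import asyncio
import os

from aiolimiter import AsyncLimiter  # pip install aiolimiter

# Same shape as the controller: QPS is env-configurable with a default of 100.
MAX_QPS = int(os.getenv("_F_MAX_QPS", "100"))
rate_limiter = AsyncLimiter(MAX_QPS, 1.0)  # at most MAX_QPS acquisitions per second

async def enqueue(i: int) -> None:
    # Every outbound call acquires the limiter first, so bursts of actions
    # are smoothed out instead of hammering the queue service.
    async with rate_limiter:
        await asyncio.sleep(0)  # stand-in for the real gRPC call
        print(f"enqueued action {i}")

async def main() -> None:
    await asyncio.gather(*(enqueue(i) for i in range(10)))

if __name__ == "__main__":
    asyncio.run(main())
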
@@ -113,13 +116,14 @@ class Controller:
                  raise RuntimeError("Failure event not initialized")
              self._failure_event.set()
          except asyncio.CancelledError:
-             pass
+             raise
 
      async def _bg_watch_for_errors(self):
          if self._failure_event is None:
              raise RuntimeError("Failure event not initialized")
          await self._failure_event.wait()
          logger.warning(f"Failure event received: {self._failure_event}, cleaning up informers and exiting.")
+         self._running = False
 
      async def watch_for_errors(self):
          """Watch for errors in the background thread"""
@@ -158,8 +162,8 @@ class Controller:
          self._thread.start()
 
          # Wait for the thread to be ready
-         logger.info("Waiting for controller thread to be ready...")
          if not self._thread_ready.wait(timeout=self._thread_wait_timeout):
+             logger.warning("Controller thread did not finish within timeout")
              raise TimeoutError("Controller thread failed to start in time")
 
          if self._get_exception():
@@ -194,15 +198,16 @@ class Controller:
          # We will wait for this to signal that the thread is ready
          # Signal the main thread that we're ready
          logger.debug("Background thread initialization complete")
-         self._thread_ready.set()
          if sys.version_info >= (3, 11):
              async with asyncio.TaskGroup() as tg:
                  for i in range(self._workers):
-                     tg.create_task(self._bg_run())
+                     tg.create_task(self._bg_run(f"worker-{i}"))
+                 self._thread_ready.set()
          else:
              tasks = []
              for i in range(self._workers):
-                 tasks.append(asyncio.create_task(self._bg_run()))
+                 tasks.append(asyncio.create_task(self._bg_run(f"worker-{i}")))
+             self._thread_ready.set()
              await asyncio.gather(*tasks)
 
      def _bg_thread_target(self):
@@ -221,6 +226,7 @@ class Controller:
          except Exception as e:
              logger.error(f"Controller thread encountered an exception: {e}")
              self._set_exception(e)
+             self._failure_event.set()
          finally:
              if self._loop and self._loop.is_running():
                  self._loop.close()
@@ -292,21 +298,21 @@ class Controller:
          started = action.is_started()
          action.mark_cancelled()
          if started:
-             logger.info(f"Cancelling action: {action.name}")
-             try:
-                 # TODO add support when the queue service supports aborting actions
-                 # await self._queue_service.AbortQueuedAction(
-                 #     queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
-                 #     wait_for_ready=True,
-                 # )
-                 logger.info(f"Successfully cancelled action: {action.name}")
-             except grpc.aio.AioRpcError as e:
-                 if e.code() in [
-                     grpc.StatusCode.NOT_FOUND,
-                     grpc.StatusCode.FAILED_PRECONDITION,
-                 ]:
-                     logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
-                     return
+             async with self._rate_limiter:
+                 logger.info(f"Cancelling action: {action.name}")
+                 try:
+                     await self._queue_service.AbortQueuedAction(
+                         queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
+                         wait_for_ready=True,
+                     )
+                     logger.info(f"Successfully cancelled action: {action.name}")
+                 except grpc.aio.AioRpcError as e:
+                     if e.code() in [
+                         grpc.StatusCode.NOT_FOUND,
+                         grpc.StatusCode.FAILED_PRECONDITION,
+                     ]:
+                         logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
+                         return
          else:
              # If the action is not started, we have to ensure it does not get launched
              logger.info(f"Action {action.name} is not started, no need to cancel.")
@@ -320,56 +326,70 @@ class Controller:
          Attempt to launch an action.
          """
          if not action.is_started():
-             task: queue_service_pb2.TaskAction | None = None
-             trace: queue_service_pb2.TraceAction | None = None
-             if action.type == "task":
-                 if action.task is None:
-                     raise flyte.errors.RuntimeSystemError(
-                         "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+             async with self._rate_limiter:
+                 task: run_definition_pb2.TaskAction | None = None
+                 trace: run_definition_pb2.TraceAction | None = None
+                 if action.type == "task":
+                     if action.task is None:
+                         raise flyte.errors.RuntimeSystemError(
+                             "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
+                         )
+                     cache_key = None
+                     logger.info(f"Action {action.name} has cache version {action.cache_key}")
+                     if action.cache_key:
+                         cache_key = StringValue(value=action.cache_key)
+
+                     task = run_definition_pb2.TaskAction(
+                         id=task_definition_pb2.TaskIdentifier(
+                             version=action.task.task_template.id.version,
+                             org=action.task.task_template.id.org,
+                             project=action.task.task_template.id.project,
+                             domain=action.task.task_template.id.domain,
+                             name=action.task.task_template.id.name,
+                         ),
+                         spec=action.task,
+                         cache_key=cache_key,
+                         cluster=action.queue,
                      )
-             cache_key = None
-             logger.info(f"Action {action.name} has cache version {action.cache_key}")
-             if action.cache_key:
-                 cache_key = StringValue(value=action.cache_key)
-
-             task = queue_service_pb2.TaskAction(
-                 id=task_definition_pb2.TaskIdentifier(
-                     version=action.task.task_template.id.version,
-                     org=action.task.task_template.id.org,
-                     project=action.task.task_template.id.project,
-                     domain=action.task.task_template.id.domain,
-                     name=action.task.task_template.id.name,
-                 ),
-                 spec=action.task,
-                 cache_key=cache_key,
-             )
-             elif action.type == "trace":
-                 trace = action.trace
-
-             logger.debug(f"Attempting to launch action: {action.name}")
-             try:
-                 await self._queue_service.EnqueueAction(
-                     queue_service_pb2.EnqueueActionRequest(
-                         action_id=action.action_id,
-                         parent_action_name=action.parent_action_name,
-                         task=task,
-                         trace=trace,
-                         input_uri=action.inputs_uri,
-                         run_output_base=action.run_output_base,
-                         group=action.group.name if action.group else None,
-                         # Subject is not used in the current implementation
-                     ),
-                     wait_for_ready=True,
-                     timeout=self._enqueue_timeout,
-                 )
-                 logger.info(f"Successfully launched action: {action.name}")
-             except grpc.aio.AioRpcError as e:
-                 if e.code() == grpc.StatusCode.ALREADY_EXISTS:
-                     logger.info(f"Action {action.name} already exists, continuing to monitor.")
-                     return
-                 logger.exception(f"Failed to launch action: {action.name} backing off...")
-                 logger.debug(f"Action details: {action}")
-                 raise e
+                 elif action.type == "trace":
+                     trace = action.trace
+
+                 logger.debug(f"Attempting to launch action: {action.name}")
+                 try:
+                     await self._queue_service.EnqueueAction(
+                         queue_service_pb2.EnqueueActionRequest(
+                             action_id=action.action_id,
+                             parent_action_name=action.parent_action_name,
+                             task=task,
+                             trace=trace,
+                             input_uri=action.inputs_uri,
+                             run_output_base=action.run_output_base,
+                             group=action.group.name if action.group else None,
+                             # Subject is not used in the current implementation
+                         ),
+                         wait_for_ready=True,
+                         timeout=self._enqueue_timeout,
+                     )
+                     logger.info(f"Successfully launched action: {action.name}")
+                 except grpc.aio.AioRpcError as e:
+                     if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                         logger.info(f"Action {action.name} already exists, continuing to monitor.")
+                         return
+                     if e.code() in [
+                         grpc.StatusCode.FAILED_PRECONDITION,
+                         grpc.StatusCode.INVALID_ARGUMENT,
+                         grpc.StatusCode.NOT_FOUND,
+                     ]:
+                         raise flyte.errors.RuntimeSystemError(
+                             e.code().name, f"Precondition failed: {e.details()}"
+                         ) from e
+                     # For all other errors, we will retry with backoff
+                     logger.exception(
+                         f"Failed to launch action: {action.name}, Code: {e.code()}, "
+                         f"Details {e.details()} backing off..."
+                     )
+                     logger.debug(f"Action details: {action}")
+                     raise flyte.errors.SlowDownError(f"Failed to launch action: {e.details()}") from e
 
      @log
      async def _bg_process(self, action: Action):
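
The revised launch path above splits `AioRpcError`s into three buckets: ALREADY_EXISTS is benign, FAILED_PRECONDITION / INVALID_ARGUMENT / NOT_FOUND are surfaced as non-retryable `RuntimeSystemError`s, and anything else becomes a `SlowDownError` that the worker loop retries with backoff. A simplified classifier along those lines is sketched below; the exception classes here are placeholders, not Flyte's actual error types.

import grpc

class NonRetryableError(Exception):
    """Placeholder for a terminal, user-visible failure."""

class RetryWithBackoff(Exception):
    """Placeholder for a transient failure the caller should retry."""

_NON_RETRYABLE = {
    grpc.StatusCode.FAILED_PRECONDITION,
    grpc.StatusCode.INVALID_ARGUMENT,
    grpc.StatusCode.NOT_FOUND,
}

def classify_enqueue_error(code: grpc.StatusCode, details: str) -> None:
    """Mirror the decision tree used when an enqueue RPC fails."""
    if code == grpc.StatusCode.ALREADY_EXISTS:
        # The action was enqueued previously; keep monitoring it.
        return
    if code in _NON_RETRYABLE:
        raise NonRetryableError(f"{code.name}: {details}")
    # Anything else (UNAVAILABLE, RESOURCE_EXHAUSTED, DEADLINE_EXCEEDED, ...)
    # is assumed transient and retried with backoff by the worker loop.
    raise RetryWithBackoff(f"Failed to launch action: {details}")

if __name__ == "__main__":
    classify_enqueue_error(grpc.StatusCode.ALREADY_EXISTS, "duplicate")  # no-op
    try:
        classify_enqueue_error(grpc.StatusCode.UNAVAILABLE, "queue busy")
    except RetryWithBackoff as e:
        print("would back off:", e)
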
@@ -397,35 +417,43 @@ class Controller:
              await asyncio.sleep(self._resource_log_interval)
 
      @log
-     async def _bg_run(self):
+     async def _bg_run(self, worker_id: str):
          """Run loop with resource status logging"""
+         logger.info(f"Worker {worker_id} started")
          while self._running:
              logger.debug(f"{threading.current_thread().name} Waiting for resource")
              action = await self._shared_queue.get()
              logger.debug(f"{threading.current_thread().name} Got resource {action.name}")
              try:
                  await self._bg_process(action)
-             except Exception as e:
-                 logger.error(f"Error in controller loop: {e}")
-                 # TODO we need a better way of handling backoffs currently the entire worker coroutine backs off
-                 await asyncio.sleep(self._min_backoff_on_err)
-                 action.increment_retries()
+             except flyte.errors.SlowDownError as e:
+                 action.retries += 1
                  if action.retries > self._max_retries:
-                     err = flyte.errors.RuntimeSystemError(
-                         code=type(e).__name__,
-                         message=f"Controller failed, system retries {action.retries}"
-                         f" crossed threshold {self._max_retries}",
-                     )
-                     err.__cause__ = e
-                     action.set_client_error(err)
-                     informer = await self._informers.get(
-                         run_name=action.run_name,
-                         parent_action_name=action.parent_action_name,
-                     )
-                     if informer:
-                         await informer.fire_completion_event(action.name)
-                 else:
-                     await self._shared_queue.put(action)
+                     raise
+                 backoff = min(self._min_backoff_on_err * (2 ** (action.retries - 1)), self._max_backoff_on_err)
+                 logger.warning(
+                     f"[{worker_id}] Backing off for {backoff} [retry {action.retries}/{self._max_retries}] "
+                     f"on action {action.name} due to error: {e}"
+                 )
+                 await asyncio.sleep(backoff)
+                 logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
+                 await self._shared_queue.put(action)
+             except Exception as e:
+                 logger.error(f"[{worker_id}] Error in controller loop for {action.name}: {e}")
+                 err = flyte.errors.RuntimeSystemError(
+                     code=type(e).__name__,
+                     message=f"Controller failed, system retries {action.retries} / {self._max_retries} "
+                     f"crossed threshold, for action {action.name}: {e}",
+                     worker=worker_id,
+                 )
+                 err.__cause__ = e
+                 action.set_client_error(err)
+                 informer = await self._informers.get(
+                     run_name=action.run_name,
+                     parent_action_name=action.parent_action_name,
+                 )
+                 if informer:
+                     await informer.fire_completion_event(action.name)
              finally:
                  self._shared_queue.task_done()
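
The backoff introduced in `_bg_run` above is a capped exponential: `min(min_backoff * 2**(retries - 1), max_backoff)`, with the minimum now defaulting to 0.5s and the cap taken from `_F_MAX_BFF_ON_ERR` (10s by default), so the delays run 0.5, 1, 2, 4, 8, 10, 10, ... seconds. A small self-contained sketch of that retry loop follows; `flaky_call` is made up for illustration and stands in for the enqueue RPC.

import asyncio
import os
import random

MIN_BACKOFF = 0.5
MAX_BACKOFF = float(os.getenv("_F_MAX_BFF_ON_ERR", "10.0"))
MAX_RETRIES = int(os.getenv("_F_MAX_RETRIES", "10"))

async def flaky_call() -> str:
    # Stand-in for the enqueue RPC: fails often enough to exercise the backoff path.
    if random.random() < 0.7:
        raise RuntimeError("transient failure")
    return "ok"

async def run_with_backoff() -> str:
    retries = 0
    while True:
        try:
            return await flaky_call()
        except RuntimeError as e:
            retries += 1
            if retries > MAX_RETRIES:
                raise
            # Same shape as the controller: exponential growth capped at MAX_BACKOFF.
            backoff = min(MIN_BACKOFF * (2 ** (retries - 1)), MAX_BACKOFF)
            print(f"retry {retries}/{MAX_RETRIES} after {backoff:.1f}s: {e}")
            await asyncio.sleep(backoff)

if __name__ == "__main__":
    print(asyncio.run(run_with_backoff()))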