flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_interface.py CHANGED
@@ -51,6 +51,8 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
51
51
  Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
52
52
  definition. In all other cases, we'll automatically generate output names, indexed starting at 0.
53
53
  """
54
+ if isinstance(return_annotation, str):
55
+ raise TypeError("String return annotations are not supported.")
54
56
 
55
57
  # Handle Option 6
56
58
  # We can think about whether we should add a default output name with type None in the future.
@@ -2,16 +2,20 @@ import asyncio
2
2
  import atexit
3
3
  import concurrent.futures
4
4
  import os
5
+ import pathlib
5
6
  import threading
6
7
  from typing import Any, Callable, Tuple, TypeVar
7
8
 
8
9
  import flyte.errors
10
+ from flyte._cache.cache import VersionParameters, cache_from_request
11
+ from flyte._cache.local_cache import LocalTaskCache
9
12
  from flyte._context import internal_ctx
10
13
  from flyte._internal.controllers import TraceInfo
11
14
  from flyte._internal.runtime import convert
12
15
  from flyte._internal.runtime.entrypoints import direct_dispatch
16
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
13
17
  from flyte._logging import log, logger
14
- from flyte._task import TaskTemplate
18
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
15
19
  from flyte._utils.helpers import _selector_policy
16
20
  from flyte.models import ActionID, NativeInterface
17
21
  from flyte.remote._task import TaskDetails
@@ -81,31 +85,67 @@ class LocalController:
81
85
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
82
86
 
83
87
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
84
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
88
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
89
+ task_interface = transform_native_to_typed_interface(_task.interface)
85
90
 
86
91
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
- tctx, _task.name, serialized_inputs, 0
92
+ tctx, _task.name, inputs_hash, 0
88
93
  )
89
94
  sub_action_raw_data_path = tctx.raw_data_path
90
-
91
- out, err = await direct_dispatch(
92
- _task,
93
- controller=self,
94
- action=sub_action_id,
95
- raw_data_path=sub_action_raw_data_path,
96
- inputs=inputs,
97
- version=tctx.version,
98
- checkpoints=tctx.checkpoints,
99
- code_bundle=tctx.code_bundle,
100
- output_path=sub_action_output_path,
101
- run_base_dir=tctx.run_base_dir,
95
+ # Make sure the output path exists
96
+ pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
97
+ pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)
98
+
99
+ task_cache = cache_from_request(_task.cache)
100
+ cache_enabled = task_cache.is_enabled()
101
+ if isinstance(_task, AsyncFunctionTaskTemplate):
102
+ version_parameters = VersionParameters(func=_task.func, image=_task.image)
103
+ else:
104
+ version_parameters = VersionParameters(func=None, image=_task.image)
105
+ cache_version = task_cache.get_version(version_parameters)
106
+ cache_key = convert.generate_cache_key_hash(
107
+ _task.name,
108
+ inputs_hash,
109
+ task_interface,
110
+ cache_version,
111
+ list(task_cache.get_ignored_inputs()),
112
+ inputs.proto_inputs,
102
113
  )
103
- if err:
104
- exc = convert.convert_error_to_native(err)
105
- if exc:
106
- raise exc
107
- else:
108
- raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
114
+
115
+ out = None
116
+ # We only get output from cache if the cache behavior is set to auto
117
+ if task_cache.behavior == "auto":
118
+ out = await LocalTaskCache.get(cache_key)
119
+ if out is not None:
120
+ logger.info(
121
+ f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
122
+ )
123
+
124
+ if out is None:
125
+ out, err = await direct_dispatch(
126
+ _task,
127
+ controller=self,
128
+ action=sub_action_id,
129
+ raw_data_path=sub_action_raw_data_path,
130
+ inputs=inputs,
131
+ version=cache_version,
132
+ checkpoints=tctx.checkpoints,
133
+ code_bundle=tctx.code_bundle,
134
+ output_path=sub_action_output_path,
135
+ run_base_dir=tctx.run_base_dir,
136
+ )
137
+
138
+ if err:
139
+ exc = convert.convert_error_to_native(err)
140
+ if exc:
141
+ raise exc
142
+ else:
143
+ raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
144
+
145
+ # store into cache
146
+ if cache_enabled and out is not None:
147
+ await LocalTaskCache.set(cache_key, out)
148
+
109
149
  if _task.native_interface.outputs:
110
150
  if out is None:
111
151
  raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
@@ -129,7 +169,7 @@ class LocalController:
129
169
  pass
130
170
 
131
171
  async def stop(self):
132
- pass
172
+ await LocalTaskCache.close()
133
173
 
134
174
  async def watch_for_errors(self):
135
175
  pass
@@ -146,16 +186,17 @@ class LocalController:
146
186
  tctx = ctx.data.task_context
147
187
  if not tctx:
148
188
  raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")
189
+
149
190
  converted_inputs = convert.Inputs.empty()
150
191
  if _interface.inputs:
151
192
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
152
193
  assert converted_inputs
153
194
 
154
- serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
195
+ inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
155
196
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
156
197
  tctx,
157
198
  _func.__name__,
158
- serialized_inputs,
199
+ inputs_hash,
159
200
  0,
160
201
  )
161
202
  assert action_output_path
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass, field
2
2
  from typing import Any, Optional
3
3
 
4
- from flyteidl.core import interface_pb2
4
+ from flyteidl2.core import interface_pb2
5
5
 
6
6
  from flyte.models import ActionID, NativeInterface
7
7
 
@@ -1,18 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Literal
4
+ from typing import Literal, Optional
5
5
 
6
- from flyteidl.core import execution_pb2, interface_pb2
7
- from google.protobuf import timestamp_pb2
8
-
9
- from flyte._protos.common import identifier_pb2
10
- from flyte._protos.workflow import (
11
- queue_service_pb2,
6
+ from flyteidl2.common import identifier_pb2
7
+ from flyteidl2.core import execution_pb2, interface_pb2
8
+ from flyteidl2.task import common_pb2, task_definition_pb2
9
+ from flyteidl2.workflow import (
12
10
  run_definition_pb2,
13
11
  state_service_pb2,
14
- task_definition_pb2,
15
12
  )
13
+ from google.protobuf import timestamp_pb2
14
+
16
15
  from flyte.models import GroupData
17
16
 
18
17
  ActionType = Literal["task", "trace"]
@@ -31,7 +30,7 @@ class Action:
31
30
  friendly_name: str | None = None
32
31
  group: GroupData | None = None
33
32
  task: task_definition_pb2.TaskSpec | None = None
34
- trace: queue_service_pb2.TraceAction | None = None
33
+ trace: run_definition_pb2.TraceAction | None = None
35
34
  inputs_uri: str | None = None
36
35
  run_output_base: str | None = None
37
36
  realized_outputs_uri: str | None = None
@@ -39,6 +38,7 @@ class Action:
39
38
  phase: run_definition_pb2.Phase | None = None
40
39
  started: bool = False
41
40
  retries: int = 0
41
+ queue: Optional[str] = None # The queue to which this action was submitted.
42
42
  client_err: Exception | None = None # This error is set when something goes wrong in the controller.
43
43
  cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
44
44
 
@@ -122,6 +122,7 @@ class Action:
122
122
  inputs_uri: str,
123
123
  run_output_base: str,
124
124
  cache_key: str | None = None,
125
+ queue: Optional[str] = None,
125
126
  ) -> Action:
126
127
  return cls(
127
128
  action_id=sub_action_id,
@@ -132,6 +133,7 @@ class Action:
132
133
  inputs_uri=inputs_uri,
133
134
  run_output_base=run_output_base,
134
135
  cache_key=cache_key,
136
+ queue=queue,
135
137
  )
136
138
 
137
139
  @classmethod
@@ -195,12 +197,12 @@ class Action:
195
197
  realized_outputs_uri=outputs_uri,
196
198
  phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
197
199
  run_output_base=run_output_base,
198
- trace=queue_service_pb2.TraceAction(
200
+ trace=run_definition_pb2.TraceAction(
199
201
  name=friendly_name,
200
202
  phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
201
203
  start_time=st,
202
204
  end_time=et,
203
- outputs=run_definition_pb2.OutputReferences(
205
+ outputs=common_pb2.OutputReferences(
204
206
  output_uri=outputs_uri,
205
207
  report_uri=report_uri,
206
208
  ),
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import grpc.aio
4
+ from flyteidl2.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
4
5
 
5
- from flyte._protos.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
6
6
  from flyte.remote import create_channel
7
7
 
8
8
  from ._service_protocol import QueueService, StateService
@@ -9,6 +9,9 @@ from collections.abc import Callable
9
9
  from pathlib import Path
10
10
  from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
11
11
 
12
+ from flyteidl2.common import identifier_pb2
13
+ from flyteidl2.workflow import run_definition_pb2
14
+
12
15
  import flyte
13
16
  import flyte.errors
14
17
  import flyte.storage as storage
@@ -22,8 +25,6 @@ from flyte._internal.runtime import convert, io
22
25
  from flyte._internal.runtime.task_serde import translate_task_to_wire
23
26
  from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
24
27
  from flyte._logging import logger
25
- from flyte._protos.common import identifier_pb2
26
- from flyte._protos.workflow import run_definition_pb2
27
28
  from flyte._task import TaskTemplate
28
29
  from flyte._utils.helpers import _selector_policy
29
30
  from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
@@ -126,7 +127,7 @@ class RemoteController(Controller):
126
127
  workers=workers,
127
128
  max_system_retries=max_system_retries,
128
129
  )
129
- default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
130
+ default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
130
131
  self._default_parent_concurrency = default_parent_concurrency
131
132
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
132
133
  lambda: asyncio.Semaphore(default_parent_concurrency)
@@ -238,6 +239,7 @@ class RemoteController(Controller):
238
239
  inputs_uri=inputs_uri,
239
240
  run_output_base=tctx.run_base_dir,
240
241
  cache_key=cache_key,
242
+ queue=_task.queue,
241
243
  )
242
244
 
243
245
  try:
@@ -375,11 +377,13 @@ class RemoteController(Controller):
375
377
 
376
378
  func_name = _func.__name__
377
379
  invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
380
+
378
381
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
379
382
  serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
383
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
380
384
 
381
385
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
382
- tctx, func_name, serialized_inputs, invoke_seq_num
386
+ tctx, func_name, inputs_hash, invoke_seq_num
383
387
  )
384
388
 
385
389
  inputs_uri = io.inputs_path(sub_action_output_path)
@@ -539,6 +543,7 @@ class RemoteController(Controller):
539
543
  inputs_uri=inputs_uri,
540
544
  run_output_base=tctx.run_base_dir,
541
545
  cache_key=cache_key,
546
+ queue=None,
542
547
  )
543
548
 
544
549
  try:
@@ -9,15 +9,13 @@ from typing import Awaitable, Coroutine, Optional
9
9
 
10
10
  import grpc.aio
11
11
  from aiolimiter import AsyncLimiter
12
+ from flyteidl2.common import identifier_pb2
13
+ from flyteidl2.task import task_definition_pb2
14
+ from flyteidl2.workflow import queue_service_pb2, run_definition_pb2
12
15
  from google.protobuf.wrappers_pb2 import StringValue
13
16
 
14
17
  import flyte.errors
15
18
  from flyte._logging import log, logger
16
- from flyte._protos.common import identifier_pb2
17
- from flyte._protos.workflow import (
18
- queue_service_pb2,
19
- task_definition_pb2,
20
- )
21
19
 
22
20
  from ._action import Action
23
21
  from ._informer import InformerCache
@@ -118,13 +116,14 @@ class Controller:
118
116
  raise RuntimeError("Failure event not initialized")
119
117
  self._failure_event.set()
120
118
  except asyncio.CancelledError:
121
- pass
119
+ raise
122
120
 
123
121
  async def _bg_watch_for_errors(self):
124
122
  if self._failure_event is None:
125
123
  raise RuntimeError("Failure event not initialized")
126
124
  await self._failure_event.wait()
127
125
  logger.warning(f"Failure event received: {self._failure_event}, cleaning up informers and exiting.")
126
+ self._running = False
128
127
 
129
128
  async def watch_for_errors(self):
130
129
  """Watch for errors in the background thread"""
@@ -302,11 +301,10 @@ class Controller:
302
301
  async with self._rate_limiter:
303
302
  logger.info(f"Cancelling action: {action.name}")
304
303
  try:
305
- # TODO add support when the queue service supports aborting actions
306
- # await self._queue_service.AbortQueuedAction(
307
- # queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
308
- # wait_for_ready=True,
309
- # )
304
+ await self._queue_service.AbortQueuedAction(
305
+ queue_service_pb2.AbortQueuedActionRequest(action_id=action.action_id),
306
+ wait_for_ready=True,
307
+ )
310
308
  logger.info(f"Successfully cancelled action: {action.name}")
311
309
  except grpc.aio.AioRpcError as e:
312
310
  if e.code() in [
@@ -329,8 +327,8 @@ class Controller:
329
327
  """
330
328
  if not action.is_started():
331
329
  async with self._rate_limiter:
332
- task: queue_service_pb2.TaskAction | None = None
333
- trace: queue_service_pb2.TraceAction | None = None
330
+ task: run_definition_pb2.TaskAction | None = None
331
+ trace: run_definition_pb2.TraceAction | None = None
334
332
  if action.type == "task":
335
333
  if action.task is None:
336
334
  raise flyte.errors.RuntimeSystemError(
@@ -341,7 +339,7 @@ class Controller:
341
339
  if action.cache_key:
342
340
  cache_key = StringValue(value=action.cache_key)
343
341
 
344
- task = queue_service_pb2.TaskAction(
342
+ task = run_definition_pb2.TaskAction(
345
343
  id=task_definition_pb2.TaskIdentifier(
346
344
  version=action.task.task_template.id.version,
347
345
  org=action.task.task_template.id.org,
@@ -351,6 +349,7 @@ class Controller:
351
349
  ),
352
350
  spec=action.task,
353
351
  cache_key=cache_key,
352
+ cluster=action.queue,
354
353
  )
355
354
  elif action.type == "trace":
356
355
  trace = action.trace
@@ -440,10 +439,11 @@ class Controller:
440
439
  logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
441
440
  await self._shared_queue.put(action)
442
441
  except Exception as e:
443
- logger.error(f"[{worker_id}] Error in controller loop: {e}")
442
+ logger.error(f"[{worker_id}] Error in controller loop for {action.name}: {e}")
444
443
  err = flyte.errors.RuntimeSystemError(
445
444
  code=type(e).__name__,
446
- message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
445
+ message=f"Controller failed, system retries {action.retries} / {self._max_retries} "
446
+ f"crossed threshold, for action {action.name}: {e}",
447
447
  worker=worker_id,
448
448
  )
449
449
  err.__cause__ = e
@@ -5,10 +5,10 @@ from asyncio import Queue
5
5
  from typing import AsyncIterator, Callable, Dict, Optional, Tuple, cast
6
6
 
7
7
  import grpc.aio
8
+ from flyteidl2.common import identifier_pb2
9
+ from flyteidl2.workflow import run_definition_pb2, state_service_pb2
8
10
 
9
11
  from flyte._logging import log, logger
10
- from flyte._protos.common import identifier_pb2
11
- from flyte._protos.workflow import run_definition_pb2, state_service_pb2
12
12
 
13
13
  from ._action import Action
14
14
  from ._service_protocol import StateService
@@ -270,7 +270,7 @@ class Informer:
270
270
  logger.warning("Informer already running")
271
271
  return cast(asyncio.Task, self._watch_task)
272
272
  self._running = True
273
- self._watch_task = asyncio.create_task(self.watch())
273
+ self._watch_task = asyncio.create_task(self.watch(), name=f"InformerWatch-{self.parent_action_name}")
274
274
  await self.wait_for_cache_sync(timeout=timeout)
275
275
  return self._watch_task
276
276
 
@@ -373,7 +373,7 @@ class InformerCache:
373
373
  """Stop all informers and remove them from the cache"""
374
374
  async with self._lock:
375
375
  while self._cache:
376
- name, informer = self._cache.popitem()
376
+ _name, informer = self._cache.popitem()
377
377
  try:
378
378
  await informer.stop()
379
379
  except asyncio.CancelledError:
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from typing import AsyncIterator, Protocol
4
4
 
5
- from flyte._protos.workflow import queue_service_pb2, state_service_pb2
5
+ from flyteidl2.workflow import queue_service_pb2, state_service_pb2
6
6
 
7
7
 
8
8
  class StateService(Protocol):
@@ -28,12 +28,12 @@ class QueueService(Protocol):
28
28
  ) -> queue_service_pb2.EnqueueActionResponse:
29
29
  """Enqueue a task"""
30
30
 
31
- # async def AbortQueuedAction(
32
- # self,
33
- # req: queue_service_pb2.AbortQueuedActionRequest,
34
- # **kwargs,
35
- # ) -> queue_service_pb2.AbortQueuedActionResponse:
36
- # """Dequeue a task"""
31
+ async def AbortQueuedAction(
32
+ self,
33
+ req: queue_service_pb2.AbortQueuedActionRequest,
34
+ **kwargs,
35
+ ) -> queue_service_pb2.AbortQueuedActionResponse:
36
+ """Cancel an enqueued task"""
37
37
 
38
38
 
39
39
  class ClientSet(Protocol):