flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
@@ -1,46 +1,62 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import concurrent.futures
5
+ import os
6
+ import threading
4
7
  from collections import defaultdict
8
+ from collections.abc import Callable
5
9
  from pathlib import Path
6
10
  from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
7
11
 
12
+ from flyteidl2.common import identifier_pb2, phase_pb2
13
+
8
14
  import flyte
9
15
  import flyte.errors
10
16
  import flyte.storage as storage
11
- import flyte.types as types
12
17
  from flyte._code_bundle import build_pkl_bundle
13
18
  from flyte._context import internal_ctx
14
- from flyte._datastructures import ActionID, NativeInterface, SerializationContext
15
19
  from flyte._internal.controllers import TraceInfo
16
20
  from flyte._internal.controllers.remote._action import Action
17
21
  from flyte._internal.controllers.remote._core import Controller
18
22
  from flyte._internal.controllers.remote._service_protocol import ClientSet
19
23
  from flyte._internal.runtime import convert, io
20
24
  from flyte._internal.runtime.task_serde import translate_task_to_wire
25
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
21
26
  from flyte._logging import logger
22
- from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
27
+ from flyte._metrics import Stopwatch
23
28
  from flyte._task import TaskTemplate
29
+ from flyte._utils.helpers import _selector_policy
30
+ from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext
31
+ from flyte.remote._task import TaskDetails
24
32
 
25
33
  R = TypeVar("R")
26
34
 
35
+ MAX_TRACE_BYTES = MAX_INLINE_IO_BYTES
36
+
27
37
 
28
- async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> None:
38
+ async def upload_inputs_with_retry(serialized_inputs: bytes, inputs_uri: str, max_bytes: int) -> None:
29
39
  """
30
40
  Upload inputs to the specified URI with error handling.
31
41
 
32
42
  Args:
33
- inputs: The inputs to upload
43
+ serialized_inputs: The serialized inputs to upload
34
44
  inputs_uri: The destination URI
45
+ max_bytes: Maximum number of bytes to read from the input stream
35
46
 
36
47
  Raises:
37
48
  RuntimeSystemError: If the upload fails
38
49
  """
50
+ if len(serialized_inputs) > max_bytes:
51
+ raise flyte.errors.InlineIOMaxBytesBreached(
52
+ f"Inputs exceed max_bytes limit of {max_bytes / 1024 / 1024} MB,"
53
+ f" actual size: {len(serialized_inputs) / 1024 / 1024} MB"
54
+ )
39
55
  try:
40
56
  # TODO Add retry decorator to this
41
- await io.upload_inputs(inputs, inputs_uri)
57
+ await storage.put_stream(serialized_inputs, to_path=inputs_uri)
42
58
  except Exception as e:
43
- logger.exception("Failed to upload inputs", e)
59
+ logger.exception("Failed to upload inputs")
44
60
  raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
45
61
 
46
62
 
@@ -56,7 +72,7 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception:
56
72
  Exception: The converted native exception or RuntimeSystemError
57
73
  """
58
74
  err = action.err or action.client_err
59
- if not err and action.phase == run_definition_pb2.PHASE_FAILED:
75
+ if not err and action.phase == phase_pb2.ACTION_PHASE_FAILED:
60
76
  logger.error(f"Server reported failure for action {action.name}, checking error file.")
61
77
  try:
62
78
  error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1")
@@ -73,22 +89,27 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception:
73
89
  return exc
74
90
 
75
91
 
76
- async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any:
92
+ async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str, max_bytes: int) -> Any:
77
93
  """
78
94
  Load outputs from the given URI and convert them to native format.
79
95
 
80
96
  Args:
81
97
  iface: The Native interface
82
98
  realized_outputs_uri: The URI where outputs are stored
99
+ max_bytes: Maximum number of bytes to read from the output file
83
100
 
84
101
  Returns:
85
102
  The converted native outputs
86
103
  """
87
104
  outputs_file_path = io.outputs_path(realized_outputs_uri)
88
- outputs = await io.load_outputs(outputs_file_path)
105
+ outputs = await io.load_outputs(outputs_file_path, max_bytes=max_bytes)
89
106
  return await convert.convert_outputs_to_native(iface, outputs)
90
107
 
91
108
 
109
+ def unique_action_name(action_id: ActionID) -> str:
110
+ return f"{action_id.name}_{action_id.run_name}"
111
+
112
+
92
113
  class RemoteController(Controller):
93
114
  """
94
115
  This a specialized controller that wraps the core controller and performs IO, serialization and deserialization
@@ -97,9 +118,8 @@ class RemoteController(Controller):
97
118
  def __init__(
98
119
  self,
99
120
  client_coro: Awaitable[ClientSet],
100
- workers: int,
101
- max_system_retries: int,
102
- default_parent_concurrency: int = 100,
121
+ workers: int = 20,
122
+ max_system_retries: int = 10,
103
123
  ):
104
124
  """ """
105
125
  super().__init__(
@@ -107,35 +127,56 @@ class RemoteController(Controller):
107
127
  workers=workers,
108
128
  max_system_retries=max_system_retries,
109
129
  )
130
+ default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
110
131
  self._default_parent_concurrency = default_parent_concurrency
111
132
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
112
133
  lambda: asyncio.Semaphore(default_parent_concurrency)
113
134
  )
135
+ self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
136
+ lambda: defaultdict(int)
137
+ )
138
+ self._submit_loop: asyncio.AbstractEventLoop | None = None
139
+ self._submit_thread: threading.Thread | None = None
114
140
 
115
- async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
141
+ def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
142
+ """
143
+ Generate a task call sequence for the given task object and action ID.
144
+ This is used to track the number of times a task is called within an action.
145
+ """
146
+ uniq = unique_action_name(action_id)
147
+ current_action_sequencer = self._parent_action_task_call_sequence[uniq]
148
+ current_task_id = id(task_obj)
149
+ v = current_action_sequencer[current_task_id]
150
+ new_seq = v + 1
151
+ current_action_sequencer[current_task_id] = new_seq
152
+ name = ""
153
+ if hasattr(task_obj, "__name__"):
154
+ name = task_obj.__name__
155
+ elif hasattr(task_obj, "name"):
156
+ name = task_obj.name
157
+ logger.info(f"For action {uniq}, task {name} call sequence is {new_seq}")
158
+ return new_seq
159
+
160
+ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
116
161
  ctx = internal_ctx()
117
162
  tctx = ctx.data.task_context
118
163
  if tctx is None:
119
164
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
120
165
  current_action_id = tctx.action
121
166
 
122
- inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
123
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
124
-
125
167
  # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
126
168
  # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
127
169
  code_bundle = tctx.code_bundle
128
170
 
129
- if code_bundle and code_bundle.pkl:
130
- logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
171
+ if tctx.interactive_mode or (code_bundle and code_bundle.pkl):
172
+ logger.debug(f"Building new pkl bundle for task {_task.name}")
131
173
  code_bundle = await build_pkl_bundle(
132
174
  _task,
133
175
  upload_to_controlplane=False,
134
- upload_from_dataplane_path=io.pkl_path(sub_action_output_path),
176
+ upload_from_dataplane_base_path=tctx.run_base_dir,
135
177
  )
136
178
 
137
- inputs_uri = io.inputs_path(sub_action_output_path)
138
- await upload_inputs_with_retry(inputs, inputs_uri)
179
+ inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
139
180
 
140
181
  root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
141
182
  # Don't set output path in sec context because node executor will set it
@@ -146,17 +187,46 @@ class RemoteController(Controller):
146
187
  code_bundle=code_bundle,
147
188
  version=tctx.version,
148
189
  # supplied version.
149
- input_path=inputs_uri,
190
+ # input_path=inputs_uri,
150
191
  image_cache=tctx.compiled_image_cache,
151
192
  root_dir=root_dir,
152
193
  )
153
194
 
154
195
  task_spec = translate_task_to_wire(_task, new_serialization_context)
196
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
197
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
198
+ tctx, task_spec, inputs_hash, _task_call_seq
199
+ )
200
+ logger.info(f"Sub action {sub_action_id} output path {sub_action_output_path}")
201
+
202
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
203
+ inputs_uri = io.inputs_path(sub_action_output_path)
204
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=_task.max_inline_io_bytes)
205
+
206
+ md = task_spec.task_template.metadata
207
+ ignored_input_vars = []
208
+ if len(md.cache_ignore_input_vars) > 0:
209
+ ignored_input_vars = list(md.cache_ignore_input_vars)
210
+ cache_key = None
211
+ if task_spec.task_template.metadata and task_spec.task_template.metadata.discoverable:
212
+ discovery_version = task_spec.task_template.metadata.discovery_version
213
+ cache_key = convert.generate_cache_key_hash(
214
+ _task.name,
215
+ inputs_hash,
216
+ task_spec.task_template.interface,
217
+ discovery_version,
218
+ ignored_input_vars,
219
+ inputs.proto_inputs,
220
+ )
221
+
222
+ # Clear to free memory
223
+ serialized_inputs = None # type: ignore
224
+ inputs_hash = None # type: ignore
155
225
 
156
226
  action = Action.from_task(
157
- sub_action_id=run_definition_pb2.ActionIdentifier(
227
+ sub_action_id=identifier_pb2.ActionIdentifier(
158
228
  name=sub_action_id.name,
159
- run=run_definition_pb2.RunIdentifier(
229
+ run=identifier_pb2.RunIdentifier(
160
230
  name=current_action_id.run_name,
161
231
  project=current_action_id.project,
162
232
  domain=current_action_id.domain,
@@ -168,6 +238,8 @@ class RemoteController(Controller):
168
238
  task_spec=task_spec,
169
239
  inputs_uri=inputs_uri,
170
240
  run_output_base=tctx.run_base_dir,
241
+ cache_key=cache_key,
242
+ queue=_task.queue,
171
243
  )
172
244
 
173
245
  try:
@@ -183,7 +255,22 @@ class RemoteController(Controller):
183
255
  await self.cancel_action(action)
184
256
  raise
185
257
 
186
- if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
258
+ # If the action is aborted, we should abort the controller as well
259
+ if n.phase == phase_pb2.ACTION_PHASE_ABORTED:
260
+ logger.warning(f"Action {n.action_id.name} was aborted, aborting current Action{current_action_id.name}")
261
+ raise flyte.errors.RunAbortedError(
262
+ f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}"
263
+ )
264
+
265
+ if n.phase == phase_pb2.ACTION_PHASE_TIMED_OUT:
266
+ logger.warning(
267
+ f"Action {n.action_id.name} timed out, raising timeout exception Action {current_action_id.name}"
268
+ )
269
+ raise flyte.errors.TaskTimeoutError(
270
+ f"Action {n.action_id.name} timed out, raising exception in current Action {current_action_id.name}"
271
+ )
272
+
273
+ if n.has_error() or n.phase == phase_pb2.ACTION_PHASE_FAILED:
187
274
  exc = await handle_action_failure(action, _task.name)
188
275
  raise exc
189
276
 
@@ -193,7 +280,9 @@ class RemoteController(Controller):
193
280
  "RuntimeError",
194
281
  f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
195
282
  )
196
- return await load_and_convert_outputs(_task.native_interface, n.realized_outputs_uri)
283
+ return await load_and_convert_outputs(
284
+ _task.native_interface, n.realized_outputs_uri, max_bytes=_task.max_inline_io_bytes
285
+ )
197
286
  return None
198
287
 
199
288
  async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
@@ -205,31 +294,81 @@ class RemoteController(Controller):
205
294
  if tctx is None:
206
295
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
207
296
  current_action_id = tctx.action
208
- async with self._parent_action_semaphore[current_action_id.name]:
209
- return await self._submit(_task, *args, **kwargs)
297
+ task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
298
+ async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
299
+ sw = Stopwatch(f"controller-submit-{unique_action_name(current_action_id)}")
300
+ sw.start()
301
+ result = await self._submit(task_call_seq, _task, *args, **kwargs)
302
+ sw.stop()
303
+ return result
304
+
305
+ def _sync_thread_loop_runner(self) -> None:
306
+ """This method runs the event loop and should be invoked in a separate thread."""
307
+
308
+ loop = self._submit_loop
309
+ assert loop is not None
310
+ try:
311
+ loop.run_forever()
312
+ finally:
313
+ loop.close()
314
+
315
+ def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
316
+ """
317
+ This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
318
+ returning a concurrent Future that can be awaited. There's no need for a lock because this function itself is
319
+ single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool
320
+ in the LocalController.
321
+ Please see additional comments in protocol.
322
+
323
+ :param _task:
324
+ :param args:
325
+ :param kwargs:
326
+ :return:
327
+ """
328
+ if self._submit_thread is None:
329
+ # Please see LocalController for the general implementation of this pattern.
330
+ def exc_handler(loop, context):
331
+ logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")
332
+
333
+ with _selector_policy():
334
+ self._submit_loop = asyncio.new_event_loop()
335
+ self._submit_loop.set_exception_handler(exc_handler)
336
+
337
+ self._submit_thread = threading.Thread(
338
+ name=f"remote-controller-{os.getpid()}-submitter",
339
+ daemon=True,
340
+ target=self._sync_thread_loop_runner,
341
+ )
342
+ self._submit_thread.start()
343
+
344
+ coro = self.submit(_task, *args, **kwargs)
345
+ assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
346
+ fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
347
+ return fut
210
348
 
211
349
  async def finalize_parent_action(self, action_id: ActionID):
212
350
  """
213
351
  This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
214
352
  to the control plane.
215
353
  """
216
- run_id = run_definition_pb2.RunIdentifier(
354
+ run_id = identifier_pb2.RunIdentifier(
217
355
  name=action_id.run_name,
218
356
  project=action_id.project,
219
357
  domain=action_id.domain,
220
358
  org=action_id.org,
221
359
  )
222
360
  await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
223
- self._parent_action_semaphore.pop(action_id.name, None)
361
+ self._parent_action_semaphore.pop(unique_action_name(action_id), None)
362
+ self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)
224
363
 
225
364
  async def get_action_outputs(
226
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
365
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
227
366
  ) -> Tuple[TraceInfo, bool]:
228
367
  """
229
368
  This method returns the outputs of the action, if it is available.
230
369
  If not available it raises a NotFoundError.
231
370
  :param _interface: NativeInterface
232
- :param _func_name: Function name
371
+ :param _func: Function name
233
372
  :param args: Arguments
234
373
  :param kwargs: Keyword arguments
235
374
  :return:
@@ -240,16 +379,26 @@ class RemoteController(Controller):
240
379
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
241
380
  current_action_id = tctx.action
242
381
 
382
+ func_name = _func.__name__
383
+ invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
384
+
243
385
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
244
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _func_name, inputs)
386
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
387
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
388
+
389
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
390
+ tctx, func_name, inputs_hash, invoke_seq_num
391
+ )
245
392
 
246
393
  inputs_uri = io.inputs_path(sub_action_output_path)
247
- await upload_inputs_with_retry(inputs, inputs_uri)
394
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=MAX_TRACE_BYTES)
395
+ # Clear to free memory
396
+ serialized_inputs = None # type: ignore
248
397
 
249
398
  prev_action = await self.get_action(
250
- run_definition_pb2.ActionIdentifier(
399
+ identifier_pb2.ActionIdentifier(
251
400
  name=sub_action_id.name,
252
- run=run_definition_pb2.RunIdentifier(
401
+ run=identifier_pb2.RunIdentifier(
253
402
  name=current_action_id.run_name,
254
403
  project=current_action_id.project,
255
404
  domain=current_action_id.domain,
@@ -260,21 +409,26 @@ class RemoteController(Controller):
260
409
  )
261
410
 
262
411
  if prev_action is None:
263
- return TraceInfo(sub_action_id, _interface, inputs_uri), False
412
+ return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
264
413
 
265
- if prev_action.phase == run_definition_pb2.PHASE_FAILED:
414
+ if prev_action.phase == phase_pb2.ACTION_PHASE_FAILED:
266
415
  if prev_action.has_error():
267
416
  exc = convert.convert_error_to_native(prev_action.err)
268
- return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
417
+ return (
418
+ TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc),
419
+ True,
420
+ )
269
421
  else:
270
422
  logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
271
423
  elif prev_action.realized_outputs_uri is not None:
272
- outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
273
- o = await io.load_outputs(outputs_file_path)
424
+ o = await io.load_outputs(prev_action.realized_outputs_uri, max_bytes=MAX_TRACE_BYTES)
274
425
  outputs = await convert.convert_outputs_to_native(_interface, o)
275
- return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True
426
+ return (
427
+ TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs),
428
+ True,
429
+ )
276
430
 
277
- return TraceInfo(sub_action_id, _interface, inputs_uri), False
431
+ return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
278
432
 
279
433
  async def record_trace(self, info: TraceInfo):
280
434
  """
@@ -287,40 +441,100 @@ class RemoteController(Controller):
287
441
  if tctx is None:
288
442
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
289
443
 
290
- current_output_path = tctx.output_path
291
- sub_run_output_path = storage.join(current_output_path, info.action.name)
444
+ current_action_id = tctx.action
445
+ sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name)
446
+ outputs_file_path: str = ""
292
447
 
293
448
  if info.interface.has_outputs():
294
- if info.output:
295
- outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
296
- outputs_file_path = io.outputs_path(sub_run_output_path)
297
- await io.upload_outputs(outputs, outputs_file_path)
298
- elif info.error:
449
+ if info.error:
299
450
  err = convert.convert_from_native_to_error(info.error)
300
- error_path = io.error_path(sub_run_output_path)
301
- await io.upload_error(err.err, error_path)
451
+ await io.upload_error(err.err, sub_run_output_path)
302
452
  else:
303
- raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
453
+ outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
454
+ outputs_file_path = io.outputs_path(sub_run_output_path)
455
+ await io.upload_outputs(outputs, sub_run_output_path, max_bytes=MAX_TRACE_BYTES)
456
+
457
+ typed_interface = transform_native_to_typed_interface(info.interface)
458
+
459
+ trace_action = Action.from_trace(
460
+ parent_action_name=current_action_id.name,
461
+ action_id=identifier_pb2.ActionIdentifier(
462
+ name=info.action.name,
463
+ run=identifier_pb2.RunIdentifier(
464
+ name=current_action_id.run_name,
465
+ project=current_action_id.project,
466
+ domain=current_action_id.domain,
467
+ org=current_action_id.org,
468
+ ),
469
+ ),
470
+ inputs_uri=info.inputs_path,
471
+ outputs_uri=outputs_file_path,
472
+ friendly_name=info.name,
473
+ group_data=tctx.group_data,
474
+ run_output_base=tctx.run_base_dir,
475
+ start_time=info.start_time,
476
+ end_time=info.end_time,
477
+ typed_interface=typed_interface if typed_interface else None,
478
+ )
304
479
 
305
- async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
480
+ async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
481
+ try:
482
+ logger.info(
483
+ f"Submitting Trace action Run:[{trace_action.run_name},"
484
+ f" Parent:[{trace_action.parent_action_name}],"
485
+ f" Trace fn:[{info.name}], action:[{info.action.name}]"
486
+ )
487
+ await self.submit_action(trace_action)
488
+ logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")
489
+ except asyncio.CancelledError:
490
+ # If the action is cancelled, we need to cancel the action on the server as well
491
+ raise
492
+
493
+ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, **kwargs) -> Any:
306
494
  ctx = internal_ctx()
307
495
  tctx = ctx.data.task_context
308
496
  if tctx is None:
309
497
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
310
498
  current_action_id = tctx.action
311
- task_name = _task.spec.task_template.id.name
499
+ task_name = _task.name
500
+
501
+ native_interface = _task.interface
502
+ pb_interface = _task.pb2.spec.task_template.interface
312
503
 
313
- native_interface = types.guess_interface(_task.spec.task_template.interface)
314
504
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
315
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, task_name, inputs)
505
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
506
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
507
+ tctx, task_name, inputs_hash, invoke_seq_num
508
+ )
316
509
 
510
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
317
511
  inputs_uri = io.inputs_path(sub_action_output_path)
318
- await upload_inputs_with_retry(inputs, inputs_uri)
512
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes)
513
+ # cache key - task name, task signature, inputs, cache version
514
+ cache_key = None
515
+ md = _task.pb2.spec.task_template.metadata
516
+ ignored_input_vars = []
517
+ if len(md.cache_ignore_input_vars) > 0:
518
+ ignored_input_vars = list(md.cache_ignore_input_vars)
519
+ if md and md.discoverable:
520
+ discovery_version = md.discovery_version
521
+ cache_key = convert.generate_cache_key_hash(
522
+ task_name,
523
+ inputs_hash,
524
+ pb_interface,
525
+ discovery_version,
526
+ ignored_input_vars,
527
+ inputs.proto_inputs,
528
+ )
529
+
530
+ # Clear to free memory
531
+ serialized_inputs = None # type: ignore
532
+ inputs_hash = None # type: ignore
319
533
 
320
534
  action = Action.from_task(
321
- sub_action_id=run_definition_pb2.ActionIdentifier(
535
+ sub_action_id=identifier_pb2.ActionIdentifier(
322
536
  name=sub_action_id.name,
323
- run=run_definition_pb2.RunIdentifier(
537
+ run=identifier_pb2.RunIdentifier(
324
538
  name=current_action_id.run_name,
325
539
  project=current_action_id.project,
326
540
  domain=current_action_id.domain,
@@ -329,9 +543,11 @@ class RemoteController(Controller):
329
543
  ),
330
544
  parent_action_name=current_action_id.name,
331
545
  group_data=tctx.group_data,
332
- task_spec=_task.spec,
546
+ task_spec=_task.pb2.spec,
333
547
  inputs_uri=inputs_uri,
334
548
  run_output_base=tctx.run_base_dir,
549
+ cache_key=cache_key,
550
+ queue=None,
335
551
  )
336
552
 
337
553
  try:
@@ -347,7 +563,7 @@ class RemoteController(Controller):
347
563
  await self.cancel_action(action)
348
564
  raise
349
565
 
350
- if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED:
566
+ if n.has_error() or n.phase == phase_pb2.ACTION_PHASE_FAILED:
351
567
  exc = await handle_action_failure(action, task_name)
352
568
  raise exc
353
569
 
@@ -357,5 +573,15 @@ class RemoteController(Controller):
357
573
  "RuntimeError",
358
574
  f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.",
359
575
  )
360
- return await load_and_convert_outputs(native_interface, n.realized_outputs_uri)
576
+ return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes)
361
577
  return None
578
+
579
+ async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any:
580
+ ctx = internal_ctx()
581
+ tctx = ctx.data.task_context
582
+ if tctx is None:
583
+ raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
584
+ current_action_id = tctx.action
585
+ task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
586
+ async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
587
+ return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs)