flyte 0.2.0b1__py3-none-any.whl → 0.2.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the registry's advisory page for this release for more details.

Files changed (57):
  1. flyte/__init__.py +3 -4
  2. flyte/_bin/runtime.py +21 -7
  3. flyte/_cache/cache.py +1 -2
  4. flyte/_cli/_common.py +26 -4
  5. flyte/_cli/_create.py +48 -0
  6. flyte/_cli/_deploy.py +4 -2
  7. flyte/_cli/_get.py +18 -7
  8. flyte/_cli/_run.py +1 -0
  9. flyte/_cli/main.py +11 -5
  10. flyte/_code_bundle/bundle.py +42 -11
  11. flyte/_context.py +1 -1
  12. flyte/_deploy.py +3 -1
  13. flyte/_group.py +1 -1
  14. flyte/_initialize.py +28 -247
  15. flyte/_internal/controllers/__init__.py +6 -6
  16. flyte/_internal/controllers/_local_controller.py +14 -5
  17. flyte/_internal/controllers/_trace.py +1 -1
  18. flyte/_internal/controllers/remote/__init__.py +27 -7
  19. flyte/_internal/controllers/remote/_action.py +1 -1
  20. flyte/_internal/controllers/remote/_client.py +5 -1
  21. flyte/_internal/controllers/remote/_controller.py +68 -24
  22. flyte/_internal/controllers/remote/_core.py +1 -1
  23. flyte/_internal/runtime/convert.py +34 -8
  24. flyte/_internal/runtime/entrypoints.py +1 -1
  25. flyte/_internal/runtime/io.py +3 -3
  26. flyte/_internal/runtime/task_serde.py +31 -1
  27. flyte/_internal/runtime/taskrunner.py +1 -1
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_run.py +47 -28
  30. flyte/_task.py +2 -2
  31. flyte/_task_environment.py +1 -1
  32. flyte/_trace.py +5 -6
  33. flyte/_utils/__init__.py +2 -0
  34. flyte/_utils/async_cache.py +139 -0
  35. flyte/_version.py +2 -2
  36. flyte/config/__init__.py +26 -4
  37. flyte/config/_config.py +13 -4
  38. flyte/extras/_container.py +3 -3
  39. flyte/{_datastructures.py → models.py} +3 -2
  40. flyte/remote/_client/auth/_auth_utils.py +14 -0
  41. flyte/remote/_client/auth/_channel.py +28 -3
  42. flyte/remote/_client/auth/_token_client.py +3 -3
  43. flyte/remote/_client/controlplane.py +13 -13
  44. flyte/remote/_logs.py +1 -1
  45. flyte/remote/_run.py +4 -8
  46. flyte/remote/_task.py +2 -2
  47. flyte/storage/__init__.py +5 -0
  48. flyte/storage/_config.py +233 -0
  49. flyte/storage/_storage.py +23 -3
  50. flyte/types/_interface.py +1 -1
  51. flyte/types/_type_engine.py +1 -1
  52. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/METADATA +2 -2
  53. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/RECORD +56 -54
  54. flyte/_internal/controllers/pbhash.py +0 -39
  55. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/WHEEL +0 -0
  56. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/entry_points.txt +0 -0
  57. {flyte-0.2.0b1.dist-info → flyte-0.2.0b3.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  from collections import defaultdict
5
+ from collections.abc import Callable
5
6
  from pathlib import Path
6
- from typing import Any, Awaitable, DefaultDict, Tuple, TypeVar
7
+ from typing import Any, AsyncIterable, Awaitable, DefaultDict, Tuple, TypeVar
7
8
 
8
9
  import flyte
9
10
  import flyte.errors
@@ -11,7 +12,6 @@ import flyte.storage as storage
11
12
  import flyte.types as types
12
13
  from flyte._code_bundle import build_pkl_bundle
13
14
  from flyte._context import internal_ctx
14
- from flyte._datastructures import ActionID, NativeInterface, SerializationContext
15
15
  from flyte._internal.controllers import TraceInfo
16
16
  from flyte._internal.controllers.remote._action import Action
17
17
  from flyte._internal.controllers.remote._core import Controller
@@ -21,16 +21,17 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
21
21
  from flyte._logging import logger
22
22
  from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
23
23
  from flyte._task import TaskTemplate
24
+ from flyte.models import ActionID, NativeInterface, SerializationContext
24
25
 
25
26
  R = TypeVar("R")
26
27
 
27
28
 
28
- async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> None:
29
+ async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None:
29
30
  """
30
31
  Upload inputs to the specified URI with error handling.
31
32
 
32
33
  Args:
33
- inputs: The inputs to upload
34
+ serialized_inputs: The serialized inputs to upload
34
35
  inputs_uri: The destination URI
35
36
 
36
37
  Raises:
@@ -38,9 +39,9 @@ async def upload_inputs_with_retry(inputs: convert.Inputs, inputs_uri: str) -> N
38
39
  """
39
40
  try:
40
41
  # TODO Add retry decorator to this
41
- await io.upload_inputs(inputs, inputs_uri)
42
+ await storage.put_stream(serialized_inputs, to_path=inputs_uri)
42
43
  except Exception as e:
43
- logger.exception("Failed to upload inputs", e)
44
+ logger.exception("Failed to upload inputs")
44
45
  raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e
45
46
 
46
47
 
@@ -89,6 +90,10 @@ async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri:
89
90
  return await convert.convert_outputs_to_native(iface, outputs)
90
91
 
91
92
 
93
+ def unique_action_name(action_id: ActionID) -> str:
94
+ return f"{action_id.name}_{action_id.run_name}"
95
+
96
+
92
97
  class RemoteController(Controller):
93
98
  """
94
99
  This a specialized controller that wraps the core controller and performs IO, serialization and deserialization
@@ -111,31 +116,42 @@ class RemoteController(Controller):
111
116
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
112
117
  lambda: asyncio.Semaphore(default_parent_concurrency)
113
118
  )
119
+ self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
120
+ lambda: defaultdict(int)
121
+ )
114
122
 
115
- async def _submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
123
+ def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
124
+ """
125
+ Generate a task call sequence for the given task object and action ID.
126
+ This is used to track the number of times a task is called within an action.
127
+ """
128
+ current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
129
+ current_task_id = id(task_obj)
130
+ v = current_action_sequencer[current_task_id]
131
+ new_seq = v + 1
132
+ current_action_sequencer[current_task_id] = new_seq
133
+ return new_seq
134
+
135
+ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
116
136
  ctx = internal_ctx()
117
137
  tctx = ctx.data.task_context
118
138
  if tctx is None:
119
139
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
120
140
  current_action_id = tctx.action
121
141
 
122
- inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
123
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _task.name, inputs)
124
-
125
142
  # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks
126
143
  # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run.
127
144
  code_bundle = tctx.code_bundle
128
145
 
129
146
  if code_bundle and code_bundle.pkl:
130
- logger.debug(f"Building new pkl bundle for task {sub_action_id.name}")
147
+ logger.debug(f"Building new pkl bundle for task {_task.name}")
131
148
  code_bundle = await build_pkl_bundle(
132
149
  _task,
133
150
  upload_to_controlplane=False,
134
- upload_from_dataplane_path=io.pkl_path(sub_action_output_path),
151
+ upload_from_dataplane_base_path=tctx.run_base_dir,
135
152
  )
136
153
 
137
- inputs_uri = io.inputs_path(sub_action_output_path)
138
- await upload_inputs_with_retry(inputs, inputs_uri)
154
+ inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
139
155
 
140
156
  root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd()
141
157
  # Don't set output path in sec context because node executor will set it
@@ -146,13 +162,23 @@ class RemoteController(Controller):
146
162
  code_bundle=code_bundle,
147
163
  version=tctx.version,
148
164
  # supplied version.
149
- input_path=inputs_uri,
165
+ # input_path=inputs_uri,
150
166
  image_cache=tctx.compiled_image_cache,
151
167
  root_dir=root_dir,
152
168
  )
153
169
 
154
170
  task_spec = translate_task_to_wire(_task, new_serialization_context)
155
171
 
172
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
173
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
174
+ tctx, task_spec, serialized_inputs, _task_call_seq
175
+ )
176
+
177
+ inputs_uri = io.inputs_path(sub_action_output_path)
178
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
179
+ # Clear to free memory
180
+ serialized_inputs = None # type: ignore
181
+
156
182
  action = Action.from_task(
157
183
  sub_action_id=run_definition_pb2.ActionIdentifier(
158
184
  name=sub_action_id.name,
@@ -205,8 +231,9 @@ class RemoteController(Controller):
205
231
  if tctx is None:
206
232
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
207
233
  current_action_id = tctx.action
208
- async with self._parent_action_semaphore[current_action_id.name]:
209
- return await self._submit(_task, *args, **kwargs)
234
+ task_call_seq = self.generate_task_call_sequence(_task, current_action_id)
235
+ async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
236
+ return await self._submit(task_call_seq, _task, *args, **kwargs)
210
237
 
211
238
  async def finalize_parent_action(self, action_id: ActionID):
212
239
  """
@@ -220,16 +247,17 @@ class RemoteController(Controller):
220
247
  org=action_id.org,
221
248
  )
222
249
  await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name)
223
- self._parent_action_semaphore.pop(action_id.name, None)
250
+ self._parent_action_semaphore.pop(unique_action_name(action_id), None)
251
+ self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None)
224
252
 
225
253
  async def get_action_outputs(
226
- self, _interface: NativeInterface, _func_name: str, *args, **kwargs
254
+ self, _interface: NativeInterface, _func: Callable, *args, **kwargs
227
255
  ) -> Tuple[TraceInfo, bool]:
228
256
  """
229
257
  This method returns the outputs of the action, if it is available.
230
258
  If not available it raises a NotFoundError.
231
259
  :param _interface: NativeInterface
232
- :param _func_name: Function name
260
+ :param _func: Function name
233
261
  :param args: Arguments
234
262
  :param kwargs: Keyword arguments
235
263
  :return:
@@ -240,11 +268,19 @@ class RemoteController(Controller):
240
268
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
241
269
  current_action_id = tctx.action
242
270
 
271
+ func_name = _func.__name__
272
+ invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
243
273
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
244
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, _func_name, inputs)
274
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
275
+
276
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
277
+ tctx, func_name, serialized_inputs, invoke_seq_num
278
+ )
245
279
 
246
280
  inputs_uri = io.inputs_path(sub_action_output_path)
247
- await upload_inputs_with_retry(inputs, inputs_uri)
281
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
282
+ # Clear to free memory
283
+ serialized_inputs = None # type: ignore
248
284
 
249
285
  prev_action = await self.get_action(
250
286
  run_definition_pb2.ActionIdentifier(
@@ -310,12 +346,20 @@ class RemoteController(Controller):
310
346
  current_action_id = tctx.action
311
347
  task_name = _task.spec.task_template.id.name
312
348
 
349
+ invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id)
350
+
313
351
  native_interface = types.guess_interface(_task.spec.task_template.interface)
314
352
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
315
- sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(tctx, task_name, inputs)
353
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
354
+
355
+ sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
356
+ tctx, task_name, serialized_inputs, invoke_seq_num
357
+ )
316
358
 
317
359
  inputs_uri = io.inputs_path(sub_action_output_path)
318
- await upload_inputs_with_retry(inputs, inputs_uri)
360
+ await upload_inputs_with_retry(serialized_inputs, inputs_uri)
361
+ # Clear to free memory
362
+ serialized_inputs = None # type: ignore
319
363
 
320
364
  action = Action.from_task(
321
365
  sub_action_id=run_definition_pb2.ActionIdentifier(
@@ -32,7 +32,7 @@ class Controller:
32
32
  max_system_retries: int = 5,
33
33
  resource_log_interval_sec: float = 10.0,
34
34
  min_backoff_on_err_sec: float = 0.1,
35
- thread_wait_timeout_sec: float = 10.0,
35
+ thread_wait_timeout_sec: float = 0.5,
36
36
  enqueue_timeout_sec: float = 5.0,
37
37
  ):
38
38
  """
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import base64
4
+ import hashlib
3
5
  from dataclasses import dataclass
4
6
  from typing import Any, Dict, Tuple, Union
5
7
 
@@ -7,9 +9,8 @@ from flyteidl.core import execution_pb2, literals_pb2
7
9
 
8
10
  import flyte.errors
9
11
  import flyte.storage as storage
10
- from flyte._datastructures import ActionID, NativeInterface, TaskContext
11
- from flyte._internal.controllers import pbhash
12
- from flyte._protos.workflow import run_definition_pb2
12
+ from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
13
+ from flyte.models import ActionID, NativeInterface, TaskContext
13
14
  from flyte.types import TypeEngine
14
15
 
15
16
 
@@ -185,21 +186,46 @@ def convert_from_native_to_error(err: BaseException) -> Error:
185
186
  )
186
187
 
187
188
 
188
- def generate_sub_action_id_and_output_path(tctx: TaskContext, task_name: str, inputs: Inputs) -> Tuple[ActionID, str]:
189
+ def hash_data(data: Union[str, bytes]) -> str:
190
+ """
191
+ Generate a hash for the given data. If the data is a string, it will be encoded to bytes before hashing.
192
+ :param data: The data to hash, can be a string or bytes.
193
+ :return: A hexadecimal string representation of the hash.
194
+ """
195
+ if isinstance(data, str):
196
+ data = data.encode("utf-8")
197
+ digest = hashlib.sha256(data).digest()
198
+ return base64.b64encode(digest).decode("utf-8")
199
+
200
+
201
+ def generate_sub_action_id_and_output_path(
202
+ tctx: TaskContext,
203
+ task_spec_or_name: task_definition_pb2.TaskSpec | str,
204
+ serialized_inputs: str | bytes,
205
+ invoke_seq: int,
206
+ ) -> Tuple[ActionID, str]:
189
207
  """
190
208
  Generate a sub-action ID and output path based on the current task context, task name, and inputs.
209
+
210
+ action name = current action name + task name + input hash + group name (if available)
191
211
  :param tctx:
192
- :param task_name:
193
- :param inputs:
212
+ :param task_spec_or_name: task specification or task name. Task name is only used in case of trace actions.
213
+ :param serialized_inputs:
214
+ :param invoke_seq: The sequence number of the invocation, used to differentiate between multiple invocations.
194
215
  :return:
195
216
  """
196
217
  current_action_id = tctx.action
197
218
  current_output_path = tctx.run_base_dir
198
- inputs_hash = pbhash.compute_hash_string(inputs.proto_inputs)
219
+ inputs_hash = hash_data(serialized_inputs)
220
+ if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
221
+ task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
222
+ else:
223
+ task_hash = task_spec_or_name
199
224
  sub_action_id = current_action_id.new_sub_action_from(
200
- task_name=task_name,
225
+ task_hash=task_hash,
201
226
  input_hash=inputs_hash,
202
227
  group=tctx.group_data.name if tctx.group_data else None,
228
+ task_call_seq=invoke_seq,
203
229
  )
204
230
  sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
205
231
  return sub_action_id, sub_run_output_path
@@ -3,11 +3,11 @@ from typing import List, Optional, Tuple
3
3
  import flyte.errors
4
4
  from flyte._code_bundle import download_bundle
5
5
  from flyte._context import contextual_run
6
- from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath
7
6
  from flyte._internal import Controller
8
7
  from flyte._internal.imagebuild.image_builder import ImageCache
9
8
  from flyte._logging import log, logger
10
9
  from flyte._task import TaskTemplate
10
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
11
11
 
12
12
  from .convert import Error, Inputs, Outputs
13
13
  from .task_serde import load_task
@@ -21,11 +21,11 @@ _OUTPUTS_FILE_NAME = "outputs.pb"
21
21
  _CHECKPOINT_FILE_NAME = "_flytecheckpoints"
22
22
  _ERROR_FILE_NAME = "error.pb"
23
23
  _REPORT_FILE_NAME = "report.html"
24
- _PKL_FILE_NAME = "code_bundle.pkl.gz"
24
+ _PKL_EXT = ".pkl.gz"
25
25
 
26
26
 
27
- def pkl_path(base_path: str) -> str:
28
- return storage.join(base_path, _PKL_FILE_NAME)
27
+ def pkl_path(base_path: str, pkl_name: str) -> str:
28
+ return storage.join(base_path, f"{pkl_name}{_PKL_EXT}")
29
29
 
30
30
 
31
31
  def inputs_path(base_path: str) -> str:
@@ -12,11 +12,11 @@ from google.protobuf import duration_pb2, wrappers_pb2
12
12
 
13
13
  import flyte.errors
14
14
  from flyte._cache.cache import VersionParameters, cache_from_request
15
- from flyte._datastructures import SerializationContext
16
15
  from flyte._logging import logger
17
16
  from flyte._protos.workflow import task_definition_pb2
18
17
  from flyte._secret import SecretRequest, secrets_from_request
19
18
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
19
+ from flyte.models import CodeBundle, SerializationContext
20
20
 
21
21
  from ..._retry import RetryStrategy
22
22
  from ..._timeout import TimeoutType, timeout_from_request
@@ -208,3 +208,33 @@ def _get_urun_container(
208
208
  data_config=task_template.data_loading_config(serialize_context),
209
209
  config=task_template.config(serialize_context),
210
210
  )
211
+
212
+
213
+ def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
214
+ """
215
+ Extract the code bundle from the task spec.
216
+ :param task_spec: The task spec to extract the code bundle from.
217
+ :return: The extracted code bundle or None if not present.
218
+ """
219
+ container = task_spec.task_template.container
220
+ if container and container.args:
221
+ pkl_path = None
222
+ tgz_path = None
223
+ dest_path: str = "."
224
+ version = ""
225
+ for i, v in enumerate(container.args):
226
+ if v == "--pkl":
227
+ # Extract the code bundle path from the argument
228
+ pkl_path = container.args[i + 1] if i + 1 < len(container.args) else None
229
+ elif v == "--tgz":
230
+ # Extract the code bundle path from the argument
231
+ tgz_path = container.args[i + 1] if i + 1 < len(container.args) else None
232
+ elif v == "--dest":
233
+ # Extract the destination path from the argument
234
+ dest_path = container.args[i + 1] if i + 1 < len(container.args) else "."
235
+ elif v == "--version":
236
+ # Extract the version from the argument
237
+ version = container.args[i + 1] if i + 1 < len(container.args) else ""
238
+ if pkl_path or tgz_path:
239
+ return CodeBundle(destination=dest_path, tgz=tgz_path, pkl=pkl_path, computed_version=version)
240
+ return None
@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional, Tuple
8
8
 
9
9
  import flyte.report
10
10
  from flyte._context import internal_ctx
11
- from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath, TaskContext
12
11
  from flyte._internal.imagebuild.image_builder import ImageCache
13
12
  from flyte._logging import log, logger
14
13
  from flyte._task import TaskTemplate
15
14
  from flyte.errors import CustomError, RuntimeSystemError, RuntimeUnknownError, RuntimeUserError
15
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, TaskContext
16
16
 
17
17
  from .. import Controller
18
18
  from .convert import (
@@ -2,7 +2,7 @@ from typing import Dict, Optional, TypeVar
2
2
 
3
3
  from flyteidl.core import interface_pb2
4
4
 
5
- from flyte._datastructures import NativeInterface
5
+ from flyte.models import NativeInterface
6
6
  from flyte.types._type_engine import TypeEngine
7
7
 
8
8
  T = TypeVar("T")
flyte/_run.py CHANGED
@@ -4,17 +4,12 @@ import pathlib
4
4
  import uuid
5
5
  from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
6
6
 
7
- import flyte
8
- import flyte.report
9
- from flyte import S3
7
+ from flyte.errors import InitializationError
10
8
 
11
9
  from ._api_commons import syncer
12
10
  from ._context import contextual_run, internal_ctx
13
- from ._datastructures import ActionID, Checkpoints, RawDataPath, SerializationContext, TaskContext
14
11
  from ._environment import Environment
15
12
  from ._initialize import (
16
- ABFS,
17
- GCS,
18
13
  _get_init_config,
19
14
  get_client,
20
15
  get_common_config,
@@ -22,14 +17,10 @@ from ._initialize import (
22
17
  requires_initialization,
23
18
  requires_storage,
24
19
  )
25
- from ._internal import create_controller
26
- from ._internal.runtime.io import _CHECKPOINT_FILE_NAME
27
- from ._internal.runtime.taskrunner import run_task
28
20
  from ._logging import logger
29
- from ._protos.common import identifier_pb2
30
21
  from ._task import P, R, TaskTemplate
31
22
  from ._tools import ipython_check
32
- from .errors import InitializationError
23
+ from .models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
33
24
 
34
25
  if TYPE_CHECKING:
35
26
  from flyte.remote import Run
@@ -39,6 +30,22 @@ if TYPE_CHECKING:
39
30
  Mode = Literal["local", "remote", "hybrid"]
40
31
 
41
32
 
33
+ async def _get_code_bundle_for_run(name: str) -> CodeBundle | None:
34
+ """
35
+ Get the code bundle for the run with the given name.
36
+ This is used to get the code bundle for the run when running in hybrid mode.
37
+ """
38
+ from flyte._internal.runtime.task_serde import extract_code_bundle
39
+ from flyte.remote import Run
40
+
41
+ run = await Run.get.aio(Run, name=name)
42
+ if run:
43
+ run_details = await run.details()
44
+ spec = run_details.action_details.pb2.resolved_task_spec
45
+ return extract_code_bundle(spec)
46
+ return None
47
+
48
+
42
49
  @syncer.wrap
43
50
  class _Runner:
44
51
  def __init__(
@@ -81,6 +88,7 @@ class _Runner:
81
88
  from ._deploy import build_images, plan_deploy
82
89
  from ._internal.runtime.convert import convert_from_native_to_inputs
83
90
  from ._internal.runtime.task_serde import translate_task_to_wire
91
+ from ._protos.common import identifier_pb2
84
92
  from ._protos.workflow import run_definition_pb2, run_service_pb2
85
93
 
86
94
  cfg = get_common_config()
@@ -187,9 +195,14 @@ class _Runner:
187
195
  run in the cluster remotely. This is currently only used for testing,
188
196
  over the longer term we will productize this.
189
197
  """
198
+ import flyte.report
190
199
  from flyte._code_bundle import build_code_bundle, build_pkl_bundle
191
- from flyte._datastructures import RawDataPath
192
200
  from flyte._deploy import build_images, plan_deploy
201
+ from flyte.models import RawDataPath
202
+ from flyte.storage import ABFS, GCS, S3
203
+
204
+ from ._internal import create_controller
205
+ from ._internal.runtime.taskrunner import run_task
193
206
 
194
207
  cfg = get_common_config()
195
208
 
@@ -199,17 +212,23 @@ class _Runner:
199
212
  deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
200
213
  image_cache = await build_images(deploy_plan)
201
214
 
202
- if self._interactive_mode:
203
- code_bundle = await build_pkl_bundle(
204
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
205
- )
206
- else:
207
- if self._copy_files != "none":
208
- code_bundle = await build_code_bundle(
209
- from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
215
+ code_bundle = None
216
+ if self._name is not None:
217
+ # Check if remote run service has this run name already and if exists, then extract the code bundle from it.
218
+ code_bundle = await _get_code_bundle_for_run(name=self._name)
219
+
220
+ if not code_bundle:
221
+ if self._interactive_mode:
222
+ code_bundle = await build_pkl_bundle(
223
+ obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
210
224
  )
211
225
  else:
212
- code_bundle = None
226
+ if self._copy_files != "none":
227
+ code_bundle = await build_code_bundle(
228
+ from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
229
+ )
230
+ else:
231
+ code_bundle = None
213
232
 
214
233
  version = self._version or (
215
234
  code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
@@ -217,14 +236,14 @@ class _Runner:
217
236
  if not version:
218
237
  raise ValueError("Version is required when running a task")
219
238
 
220
- project = cfg.project or "testproject"
221
- domain = cfg.domain or "development"
222
- org = cfg.org or "testorg"
239
+ project = cfg.project
240
+ domain = cfg.domain
241
+ org = cfg.org
223
242
  action_name = "a0"
224
243
  run_name = self._name
225
244
  random_id = str(uuid.uuid4())[:6]
226
245
 
227
- controller = create_controller(ct="remote", endpoint="localhost:8090", insecure=True)
246
+ controller = create_controller("remote", endpoint="localhost:8090", insecure=True)
228
247
  action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org)
229
248
 
230
249
  inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs)
@@ -242,7 +261,7 @@ class _Runner:
242
261
  output_path = self._run_base_dir
243
262
  raw_data_path = f"{output_path}/rd/{random_id}"
244
263
  raw_data_path_obj = RawDataPath(path=raw_data_path)
245
- checkpoint_path = f"{raw_data_path}/{_CHECKPOINT_FILE_NAME}"
264
+ checkpoint_path = f"{raw_data_path}/checkpoint"
246
265
  prev_checkpoint = f"{raw_data_path}/prev_checkpoint"
247
266
  checkpoints = Checkpoints(checkpoint_path, prev_checkpoint)
248
267
 
@@ -253,7 +272,7 @@ class _Runner:
253
272
  checkpoints=checkpoints,
254
273
  code_bundle=code_bundle,
255
274
  output_path=output_path,
256
- version=version,
275
+ version=version if version else "na",
257
276
  raw_data_path=raw_data_path_obj,
258
277
  compiled_image_cache=image_cache,
259
278
  run_base_dir=self._run_base_dir,
@@ -276,7 +295,7 @@ class _Runner:
276
295
  )
277
296
  from flyte._internal.runtime.entrypoints import direct_dispatch
278
297
 
279
- controller = create_controller(ct="local")
298
+ controller = create_controller("local")
280
299
  inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
281
300
  if self._name is None:
282
301
  action = ActionID.create_random()
flyte/_task.py CHANGED
@@ -25,7 +25,6 @@ from flyte.errors import RuntimeSystemError, RuntimeUserError
25
25
 
26
26
  from ._cache import Cache, CacheRequest
27
27
  from ._context import internal_ctx
28
- from ._datastructures import NativeInterface, SerializationContext
29
28
  from ._doc import Documentation
30
29
  from ._image import Image
31
30
  from ._resources import Resources
@@ -33,6 +32,7 @@ from ._retry import RetryStrategy
33
32
  from ._reusable_environment import ReusePolicy
34
33
  from ._secret import SecretRequest
35
34
  from ._timeout import TimeoutType
35
+ from .models import NativeInterface, SerializationContext
36
36
 
37
37
  if TYPE_CHECKING:
38
38
  from kubernetes.client import V1PodTemplate
@@ -224,7 +224,7 @@ class TaskTemplate(Generic[P, R]):
224
224
  # We will also check if we are not initialized, It is not expected to be not initialized
225
225
  from ._internal.controllers import get_controller
226
226
 
227
- controller = await get_controller()
227
+ controller = get_controller()
228
228
  if controller:
229
229
  return await controller.submit(self, *args, **kwargs)
230
230
  return await self.execute(*args, **kwargs)
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Literal, Opti
10
10
  import rich.repr
11
11
 
12
12
  from ._cache import CacheRequest
13
- from ._datastructures import NativeInterface
14
13
  from ._doc import Documentation
15
14
  from ._environment import Environment
16
15
  from ._image import Image
@@ -19,6 +18,7 @@ from ._retry import RetryStrategy
19
18
  from ._reusable_environment import ReusePolicy
20
19
  from ._secret import SecretRequest
21
20
  from ._task import AsyncFunctionTaskTemplate, TaskTemplate
21
+ from .models import NativeInterface
22
22
 
23
23
  if TYPE_CHECKING:
24
24
  from kubernetes.client import V1PodTemplate
flyte/_trace.py CHANGED
@@ -4,7 +4,7 @@ import time
4
4
  from datetime import timedelta
5
5
  from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
6
6
 
7
- from flyte._datastructures import NativeInterface
7
+ from flyte.models import NativeInterface
8
8
 
9
9
  T = TypeVar("T")
10
10
 
@@ -14,7 +14,6 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
14
14
  A decorator that traces function execution with timing information.
15
15
  Works with regular functions, async functions, and async generators/iterators.
16
16
  """
17
- func_name = func.__name__
18
17
 
19
18
  @functools.wraps(func)
20
19
  def wrapper_sync(*args: Any, **kwargs: Any) -> Any:
@@ -31,9 +30,9 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
31
30
  # We will also check if we are not initialized, It is not expected to be not initialized
32
31
  from ._internal.controllers import get_controller
33
32
 
34
- controller = await get_controller()
33
+ controller = get_controller()
35
34
  iface = NativeInterface.from_callable(func)
36
- info, ok = await controller.get_action_outputs(iface, func_name, *args, **kwargs)
35
+ info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
37
36
  if ok:
38
37
  if info.output:
39
38
  return info.output
@@ -74,9 +73,9 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
74
73
  # We will also check if we are not initialized, It is not expected to be not initialized
75
74
  from ._internal.controllers import get_controller
76
75
 
77
- controller = await get_controller()
76
+ controller = get_controller()
78
77
  iface = NativeInterface.from_callable(func)
79
- info, ok = await controller.get_action_outputs(iface, func_name, *args, **kwargs)
78
+ info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
80
79
  if ok:
81
80
  if info.output:
82
81
  for item in info.output:
flyte/_utils/__init__.py CHANGED
@@ -4,6 +4,7 @@ Internal utility functions.
4
4
  Except for logging, modules in this package should not depend on any other part of the repo.
5
5
  """
6
6
 
7
+ from .async_cache import AsyncLRUCache
7
8
  from .coro_management import run_coros
8
9
  from .file_handling import filehash_update, update_hasher_for_source
9
10
  from .helpers import get_cwd_editable_install
@@ -11,6 +12,7 @@ from .lazy_module import lazy_module
11
12
  from .uv_script_parser import parse_uv_script_file
12
13
 
13
14
  __all__ = [
15
+ "AsyncLRUCache",
14
16
  "filehash_update",
15
17
  "get_cwd_editable_install",
16
18
  "lazy_module",