flyte 2.0.0b21__py3-none-any.whl → 2.0.0b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (92) hide show
  1. flyte/__init__.py +5 -0
  2. flyte/_bin/runtime.py +36 -6
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +215 -0
  5. flyte/_code_bundle/bundle.py +1 -0
  6. flyte/_debug/constants.py +0 -1
  7. flyte/_debug/vscode.py +6 -1
  8. flyte/_deploy.py +204 -55
  9. flyte/_environment.py +5 -0
  10. flyte/_excepthook.py +1 -1
  11. flyte/_image.py +101 -68
  12. flyte/_initialize.py +30 -1
  13. flyte/_interface.py +3 -1
  14. flyte/_internal/controllers/_local_controller.py +64 -24
  15. flyte/_internal/controllers/remote/_action.py +4 -1
  16. flyte/_internal/controllers/remote/_controller.py +5 -2
  17. flyte/_internal/controllers/remote/_core.py +6 -3
  18. flyte/_internal/controllers/remote/_informer.py +1 -1
  19. flyte/_internal/imagebuild/docker_builder.py +95 -28
  20. flyte/_internal/imagebuild/image_builder.py +0 -5
  21. flyte/_internal/imagebuild/remote_builder.py +6 -1
  22. flyte/_internal/runtime/io.py +13 -1
  23. flyte/_internal/runtime/rusty.py +17 -2
  24. flyte/_internal/runtime/task_serde.py +15 -11
  25. flyte/_internal/runtime/taskrunner.py +1 -1
  26. flyte/_internal/runtime/trigger_serde.py +153 -0
  27. flyte/_keyring/file.py +2 -2
  28. flyte/_logging.py +1 -1
  29. flyte/_protos/common/identifier_pb2.py +19 -1
  30. flyte/_protos/common/identifier_pb2.pyi +22 -0
  31. flyte/_protos/workflow/common_pb2.py +14 -3
  32. flyte/_protos/workflow/common_pb2.pyi +49 -0
  33. flyte/_protos/workflow/queue_service_pb2.py +41 -35
  34. flyte/_protos/workflow/queue_service_pb2.pyi +26 -12
  35. flyte/_protos/workflow/queue_service_pb2_grpc.py +34 -0
  36. flyte/_protos/workflow/run_definition_pb2.py +38 -38
  37. flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
  38. flyte/_protos/workflow/run_service_pb2.py +60 -50
  39. flyte/_protos/workflow/run_service_pb2.pyi +24 -6
  40. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  41. flyte/_protos/workflow/task_definition_pb2.py +15 -11
  42. flyte/_protos/workflow/task_definition_pb2.pyi +19 -2
  43. flyte/_protos/workflow/task_service_pb2.py +18 -17
  44. flyte/_protos/workflow/task_service_pb2.pyi +5 -2
  45. flyte/_protos/workflow/trigger_definition_pb2.py +66 -0
  46. flyte/_protos/workflow/trigger_definition_pb2.pyi +117 -0
  47. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +4 -0
  48. flyte/_protos/workflow/trigger_service_pb2.py +96 -0
  49. flyte/_protos/workflow/trigger_service_pb2.pyi +110 -0
  50. flyte/_protos/workflow/trigger_service_pb2_grpc.py +281 -0
  51. flyte/_run.py +42 -15
  52. flyte/_task.py +35 -4
  53. flyte/_task_environment.py +61 -16
  54. flyte/_trigger.py +382 -0
  55. flyte/_version.py +3 -3
  56. flyte/cli/_abort.py +3 -3
  57. flyte/cli/_build.py +1 -3
  58. flyte/cli/_common.py +17 -4
  59. flyte/cli/_create.py +74 -0
  60. flyte/cli/_delete.py +23 -1
  61. flyte/cli/_deploy.py +16 -10
  62. flyte/cli/_get.py +75 -34
  63. flyte/cli/_params.py +4 -2
  64. flyte/cli/_run.py +25 -6
  65. flyte/cli/_update.py +36 -0
  66. flyte/cli/_user.py +17 -0
  67. flyte/cli/main.py +9 -1
  68. flyte/errors.py +9 -0
  69. flyte/io/_dir.py +513 -115
  70. flyte/io/_file.py +495 -135
  71. flyte/models.py +32 -0
  72. flyte/remote/__init__.py +6 -1
  73. flyte/remote/_action.py +9 -8
  74. flyte/remote/_client/_protocols.py +36 -2
  75. flyte/remote/_client/controlplane.py +19 -3
  76. flyte/remote/_run.py +42 -2
  77. flyte/remote/_task.py +14 -1
  78. flyte/remote/_trigger.py +308 -0
  79. flyte/remote/_user.py +33 -0
  80. flyte/storage/__init__.py +6 -1
  81. flyte/storage/_storage.py +119 -101
  82. flyte/types/_pickle.py +34 -7
  83. flyte/types/_type_engine.py +6 -0
  84. {flyte-2.0.0b21.data → flyte-2.0.0b23.data}/scripts/runtime.py +36 -6
  85. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/METADATA +3 -1
  86. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/RECORD +91 -79
  87. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  88. {flyte-2.0.0b21.data → flyte-2.0.0b23.data}/scripts/debug.py +0 -0
  89. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/WHEEL +0 -0
  90. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/entry_points.txt +0 -0
  91. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/licenses/LICENSE +0 -0
  92. {flyte-2.0.0b21.dist-info → flyte-2.0.0b23.dist-info}/top_level.txt +0 -0
flyte/_image.py CHANGED
@@ -1,11 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- import base64
4
3
  import hashlib
5
4
  import sys
6
5
  import typing
7
6
  from abc import abstractmethod
8
- from dataclasses import asdict, dataclass, field, fields
7
+ from dataclasses import dataclass, field
9
8
  from functools import cached_property
10
9
  from pathlib import Path
11
10
  from typing import TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Tuple, TypeVar, Union
@@ -13,6 +12,8 @@ from typing import TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Tuple
13
12
  import rich.repr
14
13
  from packaging.version import Version
15
14
 
15
+ from flyte._utils import update_hasher_for_source
16
+
16
17
  if TYPE_CHECKING:
17
18
  from flyte import Secret, SecretRequest
18
19
 
@@ -56,7 +57,6 @@ class Layer:
56
57
 
57
58
  :param hasher: The hash object to update with the layer's data.
58
59
  """
59
- print("hash hash")
60
60
 
61
61
  def validate(self):
62
62
  """
@@ -64,27 +64,6 @@ class Layer:
64
64
  :return:
65
65
  """
66
66
 
67
- def identifier(self) -> str:
68
- """
69
- This method computes a unique identifier for the layer based on its properties.
70
- It is used to identify the layer in the image cache.
71
-
72
- It is also used to compute a unique identifier for the image itself, which is a combination of all the layers.
73
- This identifier is used to look up previously built images in the image cache. So having a consistent
74
- identifier is important for the image cache to work correctly.
75
-
76
- :return: A unique identifier for the layer.
77
- """
78
- ignore_fields: list[str] = []
79
- for f in fields(self):
80
- if f.metadata.get("identifier", True) is False:
81
- ignore_fields.append(f.name)
82
- d = asdict(self)
83
- for v in ignore_fields:
84
- d.pop(v)
85
-
86
- return str(d)
87
-
88
67
 
89
68
  @rich.repr.auto
90
69
  @dataclass(kw_only=True, frozen=True, repr=True)
@@ -152,7 +131,7 @@ class PipPackages(PipOption, Layer):
152
131
  @rich.repr.auto
153
132
  @dataclass(kw_only=True, frozen=True, repr=True)
154
133
  class PythonWheels(PipOption, Layer):
155
- wheel_dir: Path = field(metadata={"identifier": False})
134
+ wheel_dir: Path
156
135
  wheel_dir_name: str = field(init=False)
157
136
  package_name: str
158
137
 
@@ -202,13 +181,56 @@ class UVProject(PipOption, Layer):
202
181
  from ._utils import filehash_update
203
182
 
204
183
  super().update_hash(hasher)
205
- filehash_update(self.uvlock, hasher)
184
+ if self.extra_args and "--no-install-project" in self.extra_args:
185
+ filehash_update(self.uvlock, hasher)
186
+ filehash_update(self.pyproject, hasher)
187
+ else:
188
+ update_hasher_for_source(self.pyproject.parent, hasher)
189
+
190
+
191
+ @rich.repr.auto
192
+ @dataclass(frozen=True, repr=True)
193
+ class PoetryProject(Layer):
194
+ """
195
+ Poetry does not use pip options, so the PoetryProject class does not inherit the PipOption class
196
+ """
197
+
198
+ pyproject: Path
199
+ poetry_lock: Path
200
+ extra_args: Optional[str] = None
201
+ secret_mounts: Optional[Tuple[str | Secret, ...]] = None
202
+
203
+ def validate(self):
204
+ if not self.pyproject.exists():
205
+ raise FileNotFoundError(f"pyproject.toml file {self.pyproject} does not exist")
206
+ if not self.pyproject.is_file():
207
+ raise ValueError(f"Pyproject file {self.pyproject} is not a file")
208
+ if not self.poetry_lock.exists():
209
+ raise ValueError(f"poetry.lock file {self.poetry_lock} does not exist")
210
+ super().validate()
211
+
212
+ def update_hash(self, hasher: hashlib._Hash):
213
+ from ._utils import filehash_update
214
+
215
+ hash_input = ""
216
+ if self.extra_args:
217
+ hash_input += self.extra_args
218
+ if self.secret_mounts:
219
+ for secret_mount in self.secret_mounts:
220
+ hash_input += str(secret_mount)
221
+ hasher.update(hash_input.encode("utf-8"))
222
+
223
+ if self.extra_args and "--no-root" in self.extra_args:
224
+ filehash_update(self.poetry_lock, hasher)
225
+ filehash_update(self.pyproject, hasher)
226
+ else:
227
+ update_hasher_for_source(self.pyproject.parent, hasher)
206
228
 
207
229
 
208
230
  @rich.repr.auto
209
231
  @dataclass(frozen=True, repr=True)
210
232
  class UVScript(PipOption, Layer):
211
- script: Path = field(metadata={"identifier": False})
233
+ script: Path
212
234
  script_name: str = field(init=False)
213
235
 
214
236
  def __post_init__(self):
@@ -284,15 +306,14 @@ class DockerIgnore(Layer):
284
306
  @rich.repr.auto
285
307
  @dataclass(frozen=True, repr=True)
286
308
  class CopyConfig(Layer):
287
- path_type: CopyConfigType = field(metadata={"identifier": True})
288
- src: Path = field(metadata={"identifier": False})
309
+ path_type: CopyConfigType
310
+ src: Path
289
311
  dst: str
290
- src_name: str = field(init=False)
312
+ src_name: str
291
313
 
292
314
  def __post_init__(self):
293
315
  if self.path_type not in (0, 1):
294
316
  raise ValueError(f"Invalid path_type {self.path_type}, must be 0 (file) or 1 (directory)")
295
- object.__setattr__(self, "src_name", self.src.name)
296
317
 
297
318
  def validate(self):
298
319
  if not self.src.exists():
@@ -381,9 +402,6 @@ class Image:
381
402
  platform: Tuple[Architecture, ...] = field(default=("linux/amd64",))
382
403
  python_version: Tuple[int, int] = field(default_factory=_detect_python_version)
383
404
 
384
- # For .auto() images. Don't compute an actual identifier.
385
- _identifier_override: Optional[str] = field(default=None, init=False)
386
-
387
405
  # Layers to be added to the image. In init, because frozen, but users shouldn't access, so underscore.
388
406
  _layers: Tuple[Layer, ...] = field(default_factory=tuple)
389
407
 
@@ -416,31 +434,6 @@ class Image:
416
434
  cls.__init__(obj, **kwargs) # run dataclass generated __init__
417
435
  return obj
418
436
 
419
- @cached_property
420
- def identifier(self) -> str:
421
- """
422
- This identifier is a hash of the layers and properties of the image. It is used to look up previously built
423
- images. Why is this useful? For example, if a user has Image.from_uv_base().with_source_file("a/local/file"),
424
- it's not necessarily the case that that file exists within the image (further commands may have removed/changed
425
- it), and certainly not the case that the path to the file, inside the image (which is used as part of the layer
426
- hash computation), is the same. That is, inside the image when a task runs, as we come across the same Image
427
- declaration, we need a way of identifying the image and its uri, without hashing all the layers again. This
428
- is what this identifier is for. See the ImageCache object for additional information.
429
-
430
- :return: A unique identifier of the Image
431
- """
432
- if self._identifier_override:
433
- return self._identifier_override
434
-
435
- # Only get the non-None values in the Image to ensure the hash is consistent
436
- # across different SDK versions.
437
- # Layers can specify a _compute_identifier optionally, but the default will just stringify
438
- image_dict = asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None and k != "_layers"})
439
- layers_str_repr = "".join([layer.identifier() for layer in self._layers])
440
- image_dict["layers"] = layers_str_repr
441
- spec_bytes = image_dict.__str__().encode("utf-8")
442
- return base64.urlsafe_b64encode(hashlib.md5(spec_bytes).digest()).decode("ascii").rstrip("=")
443
-
444
437
  def validate(self):
445
438
  for layer in self._layers:
446
439
  layer.validate()
@@ -503,9 +496,6 @@ class Image:
503
496
  image = image.with_pip_packages(f"flyte=={flyte_version}")
504
497
  if not dev_mode:
505
498
  object.__setattr__(image, "_tag", preset_tag)
506
- # Set this to auto for all auto images because the meaning of "auto" can change (based on logic inside
507
- # _get_default_image_for, acts differently in a running task container) so let's make sure it stays auto.
508
- object.__setattr__(image, "_identifier_override", "auto")
509
499
 
510
500
  return image
511
501
 
@@ -546,9 +536,6 @@ class Image:
546
536
  if registry or name:
547
537
  return base_image.clone(registry=registry, name=name)
548
538
 
549
- # # Set this to auto for all auto images because the meaning of "auto" can change (based on logic inside
550
- # # _get_default_image_for, acts differently in a running task container) so let's make sure it stays auto.
551
- # object.__setattr__(base_image, "_identifier_override", "auto")
552
539
  return base_image
553
540
 
554
541
  @classmethod
@@ -845,16 +832,23 @@ class Image:
845
832
  new_image = self.clone(addl_layer=Env.from_dict(env_vars))
846
833
  return new_image
847
834
 
848
- def with_source_folder(self, src: Path, dst: str = ".") -> Image:
835
+ def with_source_folder(self, src: Path, dst: str = ".", copy_contents_only: bool = False) -> Image:
849
836
  """
850
837
  Use this method to create a new image with the specified local directory layered on top of the current image.
851
838
  If dest is not specified, it will be copied to the working directory of the image
852
839
 
853
840
  :param src: root folder of the source code from the build context to be copied
854
841
  :param dst: destination folder in the image
842
+ :param copy_contents_only: If True, will copy the contents of the source folder to the destination folder,
843
+ instead of the folder itself. Default is False.
855
844
  :return: Image
856
845
  """
857
- new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst))
846
+ src_name = src.name
847
+ if copy_contents_only:
848
+ src_name = "."
849
+ else:
850
+ dst = str("./" + src_name)
851
+ new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst, src_name=src_name))
858
852
  return new_image
859
853
 
860
854
  def with_source_file(self, src: Path, dst: str = ".") -> Image:
@@ -866,7 +860,7 @@ class Image:
866
860
  :param dst: destination folder in the image
867
861
  :return: Image
868
862
  """
869
- new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst))
863
+ new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst, src_name=src.name))
870
864
  return new_image
871
865
 
872
866
  def with_dockerignore(self, path: Path) -> Image:
@@ -919,6 +913,45 @@ class Image:
919
913
  )
920
914
  return new_image
921
915
 
916
+ def with_poetry_project(
917
+ self,
918
+ pyproject_file: str | Path,
919
+ poetry_lock: Path | None = None,
920
+ extra_args: Optional[str] = None,
921
+ secret_mounts: Optional[SecretRequest] = None,
922
+ ):
923
+ """
924
+ Use this method to create a new image with the specified pyproject.toml layered on top of the current image.
925
+ Must have a corresponding poetry.lock file in the same directory.
926
+ Cannot be used in conjunction with conda.
927
+
928
+ By default, this method copies the entire project into the image,
929
+ including files such as pyproject.toml, poetry.lock, and the src/ directory.
930
+
931
+ If you prefer not to install the current project, you can pass through `extra_args`
932
+ `--no-root`. In this case, the image builder will only copy pyproject.toml and poetry.lock
933
+ into the image.
934
+
935
+ :param pyproject_file: Path to the pyproject.toml file. A poetry.lock file must exist in the same directory
936
+ unless `poetry_lock` is explicitly provided.
937
+ :param poetry_lock: Path to the poetry.lock file. If not specified, the default is the file named
938
+ 'poetry.lock' in the same directory as `pyproject_file` (pyproject.parent / "poetry.lock").
939
+ :param extra_args: Extra arguments to pass through to the package installer/resolver, default is None.
940
+ :param secret_mounts: Secrets to make available during dependency resolution/build (e.g., private indexes).
941
+ :return: Image
942
+ """
943
+ if isinstance(pyproject_file, str):
944
+ pyproject_file = Path(pyproject_file)
945
+ new_image = self.clone(
946
+ addl_layer=PoetryProject(
947
+ pyproject=pyproject_file,
948
+ poetry_lock=poetry_lock or (pyproject_file.parent / "poetry.lock"),
949
+ extra_args=extra_args,
950
+ secret_mounts=_ensure_tuple(secret_mounts) if secret_mounts else None,
951
+ )
952
+ )
953
+ return new_image
954
+
922
955
  def with_apt_packages(self, *packages: str, secret_mounts: Optional[SecretRequest] = None) -> Image:
923
956
  """
924
957
  Use this method to create a new image with the specified apt packages layered on top of the current image
flyte/_initialize.py CHANGED
@@ -33,6 +33,7 @@ class CommonInit:
33
33
  project: str | None = None
34
34
  domain: str | None = None
35
35
  batch_size: int = 1000
36
+ source_config_path: Optional[Path] = None # Only used for documentation
36
37
 
37
38
 
38
39
  @dataclass(init=True, kw_only=True, repr=True, eq=True, frozen=True)
@@ -140,6 +141,7 @@ async def init(
140
141
  storage: Storage | None = None,
141
142
  batch_size: int = 1000,
142
143
  image_builder: ImageBuildEngine.ImageBuilderType = "local",
144
+ source_config_path: Optional[Path] = None,
143
145
  ) -> None:
144
146
  """
145
147
  Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
@@ -175,11 +177,33 @@ async def init(
175
177
  :param batch_size: Optional batch size for operations that use listings, defaults to 1000, so limit larger than
176
178
  batch_size will be split into multiple requests.
177
179
  :param image_builder: Optional image builder configuration, if not provided, the default image builder will be used.
178
-
180
+ :param source_config_path: Optional path to the source configuration file (This is only used for documentation)
179
181
  :return: None
180
182
  """
181
183
  from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
182
184
 
185
+ if endpoint or api_key:
186
+ if project is None:
187
+ raise ValueError(
188
+ "Project must be provided to initialize the client. "
189
+ "Please set 'project' in the 'task' section of your config file, "
190
+ "or pass it directly to flyte.init(project='your-project-name')."
191
+ )
192
+
193
+ if domain is None:
194
+ raise ValueError(
195
+ "Domain must be provided to initialize the client. "
196
+ "Please set 'domain' in the 'task' section of your config file, "
197
+ "or pass it directly to flyte.init(domain='your-domain-name')."
198
+ )
199
+
200
+ if org is None and org_from_endpoint(endpoint) is None:
201
+ raise ValueError(
202
+ "Organization must be provided to initialize the client. "
203
+ "Please set 'org' in the 'task' section of your config file, "
204
+ "or pass it directly to flyte.init(org='your-org-name')."
205
+ )
206
+
183
207
  _initialize_logger(log_level=log_level)
184
208
 
185
209
  global _init_config # noqa: PLW0603
@@ -224,6 +248,7 @@ async def init(
224
248
  org=org or org_from_endpoint(endpoint),
225
249
  batch_size=batch_size,
226
250
  image_builder=image_builder,
251
+ source_config_path=source_config_path,
227
252
  )
228
253
 
229
254
 
@@ -245,6 +270,7 @@ async def init_from_config(
245
270
  if not available, the current working directory.
246
271
  :param log_level: Optional logging level for the framework logger,
247
272
  default is set using the default initialization policies
273
+ :param storage: Optional blob store (S3, GCS, Azure) configuration if needed to access (e.g. when using MinIO)
248
274
  :return: None
249
275
  """
250
276
  from rich.highlighter import ReprHighlighter
@@ -252,6 +278,7 @@ async def init_from_config(
252
278
  import flyte.config as config
253
279
 
254
280
  cfg: config.Config
281
+ cfg_path: Optional[Path] = None
255
282
  if path_or_config is None:
256
283
  # If no path is provided, use the default config file
257
284
  cfg = config.auto()
@@ -273,6 +300,7 @@ async def init_from_config(
273
300
  _initialize_logger(log_level=log_level)
274
301
 
275
302
  logger.info(f"Flyte config initialized as {cfg}", extra={"highlighter": ReprHighlighter()})
303
+
276
304
  await init.aio(
277
305
  org=cfg.task.org,
278
306
  project=cfg.task.project,
@@ -290,6 +318,7 @@ async def init_from_config(
290
318
  log_level=log_level,
291
319
  image_builder=cfg.image.builder,
292
320
  storage=storage,
321
+ source_config_path=cfg_path,
293
322
  )
294
323
 
295
324
 
flyte/_interface.py CHANGED
@@ -7,6 +7,8 @@ from typing import Dict, Generator, Literal, Tuple, Type, TypeVar, Union, cast,
7
7
 
8
8
  from flyte._logging import logger
9
9
 
10
+ LITERAL_ENUM = "LiteralEnum"
11
+
10
12
 
11
13
  def default_output_name(index: int = 0) -> str:
12
14
  return f"o{index}"
@@ -110,6 +112,6 @@ def literal_to_enum(literal_type: Type) -> Type[Enum | typing.Any]:
110
112
  enum_dict = {str(v).upper(): v for v in values}
111
113
 
112
114
  # Dynamically create an Enum
113
- literal_enum = Enum("LiteralEnum", enum_dict) # type: ignore
115
+ literal_enum = Enum(LITERAL_ENUM, enum_dict) # type: ignore
114
116
 
115
117
  return literal_enum # type: ignore
@@ -2,16 +2,20 @@ import asyncio
2
2
  import atexit
3
3
  import concurrent.futures
4
4
  import os
5
+ import pathlib
5
6
  import threading
6
7
  from typing import Any, Callable, Tuple, TypeVar
7
8
 
8
9
  import flyte.errors
10
+ from flyte._cache.cache import VersionParameters, cache_from_request
11
+ from flyte._cache.local_cache import LocalTaskCache
9
12
  from flyte._context import internal_ctx
10
13
  from flyte._internal.controllers import TraceInfo
11
14
  from flyte._internal.runtime import convert
12
15
  from flyte._internal.runtime.entrypoints import direct_dispatch
16
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
13
17
  from flyte._logging import log, logger
14
- from flyte._task import TaskTemplate
18
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
15
19
  from flyte._utils.helpers import _selector_policy
16
20
  from flyte.models import ActionID, NativeInterface
17
21
  from flyte.remote._task import TaskDetails
@@ -81,31 +85,67 @@ class LocalController:
81
85
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
82
86
 
83
87
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
84
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
88
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
89
+ task_interface = transform_native_to_typed_interface(_task.interface)
85
90
 
86
91
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
- tctx, _task.name, serialized_inputs, 0
92
+ tctx, _task.name, inputs_hash, 0
88
93
  )
89
94
  sub_action_raw_data_path = tctx.raw_data_path
90
-
91
- out, err = await direct_dispatch(
92
- _task,
93
- controller=self,
94
- action=sub_action_id,
95
- raw_data_path=sub_action_raw_data_path,
96
- inputs=inputs,
97
- version=tctx.version,
98
- checkpoints=tctx.checkpoints,
99
- code_bundle=tctx.code_bundle,
100
- output_path=sub_action_output_path,
101
- run_base_dir=tctx.run_base_dir,
95
+ # Make sure the output path exists
96
+ pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
97
+ pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)
98
+
99
+ task_cache = cache_from_request(_task.cache)
100
+ cache_enabled = task_cache.is_enabled()
101
+ if isinstance(_task, AsyncFunctionTaskTemplate):
102
+ version_parameters = VersionParameters(func=_task.func, image=_task.image)
103
+ else:
104
+ version_parameters = VersionParameters(func=None, image=_task.image)
105
+ cache_version = task_cache.get_version(version_parameters)
106
+ cache_key = convert.generate_cache_key_hash(
107
+ _task.name,
108
+ inputs_hash,
109
+ task_interface,
110
+ cache_version,
111
+ list(task_cache.get_ignored_inputs()),
112
+ inputs.proto_inputs,
102
113
  )
103
- if err:
104
- exc = convert.convert_error_to_native(err)
105
- if exc:
106
- raise exc
107
- else:
108
- raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
114
+
115
+ out = None
116
+ # We only get output from cache if the cache behavior is set to auto
117
+ if task_cache.behavior == "auto":
118
+ out = await LocalTaskCache.get(cache_key)
119
+ if out is not None:
120
+ logger.info(
121
+ f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
122
+ )
123
+
124
+ if out is None:
125
+ out, err = await direct_dispatch(
126
+ _task,
127
+ controller=self,
128
+ action=sub_action_id,
129
+ raw_data_path=sub_action_raw_data_path,
130
+ inputs=inputs,
131
+ version=cache_version,
132
+ checkpoints=tctx.checkpoints,
133
+ code_bundle=tctx.code_bundle,
134
+ output_path=sub_action_output_path,
135
+ run_base_dir=tctx.run_base_dir,
136
+ )
137
+
138
+ if err:
139
+ exc = convert.convert_error_to_native(err)
140
+ if exc:
141
+ raise exc
142
+ else:
143
+ raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
144
+
145
+ # store into cache
146
+ if cache_enabled and out is not None:
147
+ await LocalTaskCache.set(cache_key, out)
148
+
109
149
  if _task.native_interface.outputs:
110
150
  if out is None:
111
151
  raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
@@ -129,7 +169,7 @@ class LocalController:
129
169
  pass
130
170
 
131
171
  async def stop(self):
132
- pass
172
+ await LocalTaskCache.close()
133
173
 
134
174
  async def watch_for_errors(self):
135
175
  pass
@@ -151,11 +191,11 @@ class LocalController:
151
191
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
152
192
  assert converted_inputs
153
193
 
154
- serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
194
+ inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
155
195
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
156
196
  tctx,
157
197
  _func.__name__,
158
- serialized_inputs,
198
+ inputs_hash,
159
199
  0,
160
200
  )
161
201
  assert action_output_path
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Literal
4
+ from typing import Literal, Optional
5
5
 
6
6
  from flyteidl.core import execution_pb2, interface_pb2
7
7
  from google.protobuf import timestamp_pb2
@@ -39,6 +39,7 @@ class Action:
39
39
  phase: run_definition_pb2.Phase | None = None
40
40
  started: bool = False
41
41
  retries: int = 0
42
+ queue: Optional[str] = None # The queue to which this action was submitted.
42
43
  client_err: Exception | None = None # This error is set when something goes wrong in the controller.
43
44
  cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
44
45
 
@@ -122,6 +123,7 @@ class Action:
122
123
  inputs_uri: str,
123
124
  run_output_base: str,
124
125
  cache_key: str | None = None,
126
+ queue: Optional[str] = None,
125
127
  ) -> Action:
126
128
  return cls(
127
129
  action_id=sub_action_id,
@@ -132,6 +134,7 @@ class Action:
132
134
  inputs_uri=inputs_uri,
133
135
  run_output_base=run_output_base,
134
136
  cache_key=cache_key,
137
+ queue=queue,
135
138
  )
136
139
 
137
140
  @classmethod
@@ -126,7 +126,7 @@ class RemoteController(Controller):
126
126
  workers=workers,
127
127
  max_system_retries=max_system_retries,
128
128
  )
129
- default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
129
+ default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
130
130
  self._default_parent_concurrency = default_parent_concurrency
131
131
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
132
132
  lambda: asyncio.Semaphore(default_parent_concurrency)
@@ -238,6 +238,7 @@ class RemoteController(Controller):
238
238
  inputs_uri=inputs_uri,
239
239
  run_output_base=tctx.run_base_dir,
240
240
  cache_key=cache_key,
241
+ queue=_task.queue,
241
242
  )
242
243
 
243
244
  try:
@@ -377,9 +378,10 @@ class RemoteController(Controller):
377
378
  invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
378
379
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
379
380
  serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
381
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
380
382
 
381
383
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
382
- tctx, func_name, serialized_inputs, invoke_seq_num
384
+ tctx, func_name, inputs_hash, invoke_seq_num
383
385
  )
384
386
 
385
387
  inputs_uri = io.inputs_path(sub_action_output_path)
@@ -539,6 +541,7 @@ class RemoteController(Controller):
539
541
  inputs_uri=inputs_uri,
540
542
  run_output_base=tctx.run_base_dir,
541
543
  cache_key=cache_key,
544
+ queue=None,
542
545
  )
543
546
 
544
547
  try:
@@ -118,13 +118,14 @@ class Controller:
118
118
  raise RuntimeError("Failure event not initialized")
119
119
  self._failure_event.set()
120
120
  except asyncio.CancelledError:
121
- pass
121
+ raise
122
122
 
123
123
  async def _bg_watch_for_errors(self):
124
124
  if self._failure_event is None:
125
125
  raise RuntimeError("Failure event not initialized")
126
126
  await self._failure_event.wait()
127
127
  logger.warning(f"Failure event received: {self._failure_event}, cleaning up informers and exiting.")
128
+ self._running = False
128
129
 
129
130
  async def watch_for_errors(self):
130
131
  """Watch for errors in the background thread"""
@@ -351,6 +352,7 @@ class Controller:
351
352
  ),
352
353
  spec=action.task,
353
354
  cache_key=cache_key,
355
+ cluster=action.queue,
354
356
  )
355
357
  elif action.type == "trace":
356
358
  trace = action.trace
@@ -440,10 +442,11 @@ class Controller:
440
442
  logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
441
443
  await self._shared_queue.put(action)
442
444
  except Exception as e:
443
- logger.error(f"[{worker_id}] Error in controller loop: {e}")
445
+ logger.error(f"[{worker_id}] Error in controller loop for {action.name}: {e}")
444
446
  err = flyte.errors.RuntimeSystemError(
445
447
  code=type(e).__name__,
446
- message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
448
+ message=f"Controller failed, system retries {action.retries} / {self._max_retries} "
449
+ f"crossed threshold, for action {action.name}: {e}",
447
450
  worker=worker_id,
448
451
  )
449
452
  err.__cause__ = e
@@ -270,7 +270,7 @@ class Informer:
270
270
  logger.warning("Informer already running")
271
271
  return cast(asyncio.Task, self._watch_task)
272
272
  self._running = True
273
- self._watch_task = asyncio.create_task(self.watch())
273
+ self._watch_task = asyncio.create_task(self.watch(), name=f"InformerWatch-{self.parent_action_name}")
274
274
  await self.wait_for_cache_sync(timeout=timeout)
275
275
  return self._watch_task
276
276