flyte 2.0.0b23__py3-none-any.whl → 2.0.0b25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registry, and is provided for informational purposes only.

Potentially problematic release.



Files changed (162)
  1. flyte/__init__.py +11 -2
  2. flyte/_cache/local_cache.py +4 -3
  3. flyte/_code_bundle/_utils.py +3 -3
  4. flyte/_code_bundle/bundle.py +12 -5
  5. flyte/_context.py +4 -1
  6. flyte/_custom_context.py +73 -0
  7. flyte/_deploy.py +31 -7
  8. flyte/_image.py +48 -16
  9. flyte/_initialize.py +69 -26
  10. flyte/_internal/controllers/_local_controller.py +1 -0
  11. flyte/_internal/controllers/_trace.py +1 -1
  12. flyte/_internal/controllers/remote/_action.py +9 -10
  13. flyte/_internal/controllers/remote/_client.py +1 -1
  14. flyte/_internal/controllers/remote/_controller.py +4 -2
  15. flyte/_internal/controllers/remote/_core.py +10 -13
  16. flyte/_internal/controllers/remote/_informer.py +3 -3
  17. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  18. flyte/_internal/imagebuild/docker_builder.py +45 -59
  19. flyte/_internal/imagebuild/remote_builder.py +51 -11
  20. flyte/_internal/imagebuild/utils.py +51 -3
  21. flyte/_internal/runtime/convert.py +39 -18
  22. flyte/_internal/runtime/io.py +8 -7
  23. flyte/_internal/runtime/resources_serde.py +20 -6
  24. flyte/_internal/runtime/reuse.py +1 -1
  25. flyte/_internal/runtime/task_serde.py +7 -10
  26. flyte/_internal/runtime/taskrunner.py +10 -1
  27. flyte/_internal/runtime/trigger_serde.py +13 -13
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_keyring/file.py +2 -2
  30. flyte/_map.py +65 -13
  31. flyte/_pod.py +2 -2
  32. flyte/_resources.py +175 -31
  33. flyte/_run.py +37 -21
  34. flyte/_task.py +27 -6
  35. flyte/_task_environment.py +37 -10
  36. flyte/_utils/module_loader.py +2 -2
  37. flyte/_version.py +3 -3
  38. flyte/cli/_common.py +47 -5
  39. flyte/cli/_create.py +4 -0
  40. flyte/cli/_deploy.py +8 -0
  41. flyte/cli/_get.py +4 -0
  42. flyte/cli/_params.py +4 -4
  43. flyte/cli/_run.py +50 -7
  44. flyte/cli/_update.py +4 -3
  45. flyte/config/_config.py +2 -0
  46. flyte/config/_internal.py +1 -0
  47. flyte/config/_reader.py +3 -3
  48. flyte/errors.py +1 -1
  49. flyte/extend.py +4 -0
  50. flyte/extras/_container.py +6 -1
  51. flyte/git/_config.py +11 -9
  52. flyte/io/_dataframe/basic_dfs.py +1 -1
  53. flyte/io/_dataframe/dataframe.py +12 -8
  54. flyte/io/_dir.py +48 -15
  55. flyte/io/_file.py +48 -11
  56. flyte/models.py +12 -8
  57. flyte/remote/_action.py +18 -16
  58. flyte/remote/_client/_protocols.py +4 -3
  59. flyte/remote/_client/auth/_channel.py +1 -1
  60. flyte/remote/_client/controlplane.py +4 -8
  61. flyte/remote/_data.py +4 -3
  62. flyte/remote/_logs.py +3 -3
  63. flyte/remote/_run.py +5 -5
  64. flyte/remote/_secret.py +20 -13
  65. flyte/remote/_task.py +7 -8
  66. flyte/remote/_trigger.py +25 -27
  67. flyte/storage/_parallel_reader.py +274 -0
  68. flyte/storage/_storage.py +66 -2
  69. flyte/types/_interface.py +2 -2
  70. flyte/types/_pickle.py +1 -1
  71. flyte/types/_string_literals.py +8 -9
  72. flyte/types/_type_engine.py +25 -17
  73. flyte/types/_utils.py +1 -1
  74. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
  75. flyte-2.0.0b25.dist-info/RECORD +184 -0
  76. flyte/_protos/__init__.py +0 -0
  77. flyte/_protos/common/authorization_pb2.py +0 -66
  78. flyte/_protos/common/authorization_pb2.pyi +0 -108
  79. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  80. flyte/_protos/common/identifier_pb2.py +0 -117
  81. flyte/_protos/common/identifier_pb2.pyi +0 -142
  82. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  83. flyte/_protos/common/identity_pb2.py +0 -48
  84. flyte/_protos/common/identity_pb2.pyi +0 -72
  85. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  86. flyte/_protos/common/list_pb2.py +0 -36
  87. flyte/_protos/common/list_pb2.pyi +0 -71
  88. flyte/_protos/common/list_pb2_grpc.py +0 -4
  89. flyte/_protos/common/policy_pb2.py +0 -37
  90. flyte/_protos/common/policy_pb2.pyi +0 -27
  91. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  92. flyte/_protos/common/role_pb2.py +0 -37
  93. flyte/_protos/common/role_pb2.pyi +0 -53
  94. flyte/_protos/common/role_pb2_grpc.py +0 -4
  95. flyte/_protos/common/runtime_version_pb2.py +0 -28
  96. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  97. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  98. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  99. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  100. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  101. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  102. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  103. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  104. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  105. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  106. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  107. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  108. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  109. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  110. flyte/_protos/secret/definition_pb2.py +0 -49
  111. flyte/_protos/secret/definition_pb2.pyi +0 -93
  112. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  113. flyte/_protos/secret/payload_pb2.py +0 -62
  114. flyte/_protos/secret/payload_pb2.pyi +0 -94
  115. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  116. flyte/_protos/secret/secret_pb2.py +0 -38
  117. flyte/_protos/secret/secret_pb2.pyi +0 -6
  118. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  119. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  120. flyte/_protos/workflow/common_pb2.py +0 -38
  121. flyte/_protos/workflow/common_pb2.pyi +0 -63
  122. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  123. flyte/_protos/workflow/environment_pb2.py +0 -29
  124. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  125. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  126. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  127. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  128. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  129. flyte/_protos/workflow/queue_service_pb2.py +0 -117
  130. flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
  131. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
  132. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  133. flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
  134. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  135. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  136. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  137. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  138. flyte/_protos/workflow/run_service_pb2.py +0 -147
  139. flyte/_protos/workflow/run_service_pb2.pyi +0 -203
  140. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
  141. flyte/_protos/workflow/state_service_pb2.py +0 -67
  142. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  143. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  144. flyte/_protos/workflow/task_definition_pb2.py +0 -86
  145. flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
  146. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  147. flyte/_protos/workflow/task_service_pb2.py +0 -61
  148. flyte/_protos/workflow/task_service_pb2.pyi +0 -62
  149. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  150. flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
  151. flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
  152. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
  153. flyte/_protos/workflow/trigger_service_pb2.py +0 -96
  154. flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
  155. flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
  156. flyte-2.0.0b23.dist-info/RECORD +0 -262
  157. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
  158. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
  159. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
  160. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
  161. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
  162. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ import typing
  from pathlib import Path
  from typing import TypeVar

- from flyteidl.core import literals_pb2, types_pb2
+ from flyteidl2.core import literals_pb2, types_pb2
  from fsspec.core import split_protocol, strip_protocol

  import flyte.storage as storage
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
  from dataclasses import is_dataclass
  from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union

- from flyteidl.core import literals_pb2, types_pb2
+ from flyteidl2.core import literals_pb2, types_pb2
  from fsspec.utils import get_protocol
  from mashumaro.types import SerializableType
  from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
@@ -647,17 +647,21 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
  f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
  )
  lowest_level[h.supported_format] = h
- logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")
+ logger.debug(
+ f"Registered {h.__class__.__name__} as handler for {h.python_type.__class__.__name__},"
+ f" protocol {protocol}, fmt {h.supported_format}"
+ )

  if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
  if h.python_type in cls.DEFAULT_FORMATS and not override:
  if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
  logger.info(
- f"Not using handler {h} with format {h.supported_format}"
- f" as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+ f"Not using handler {h.__class__.__name__} with format {h.supported_format}"
+ f" as default for {h.python_type.__class__.__name__},"
+ f" {cls.DEFAULT_FORMATS[h.python_type]} already specified."
  )
  else:
- logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type}.")
+ logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type.__class__.__name__}.")
  cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
  if default_storage_for_type or default_for_type:
  if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
@@ -685,7 +689,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
  expected: types_pb2.LiteralType,
  ) -> literals_pb2.Literal:
  # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
- python_type, *attrs = extract_cols_and_format(python_type)
+ python_type, *_attrs = extract_cols_and_format(python_type)
  sdt = types_pb2.StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))

  if issubclass(python_type, DataFrame) and not isinstance(python_val, DataFrame):
@@ -876,7 +880,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
  raise TypeTransformerFailedError("Attribute access unsupported.")

  # Detect annotations and extract out all the relevant information that the user might supply
- expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
+ expected_python_type, column_dict, _storage_fmt, _pa_schema = extract_cols_and_format(expected_python_type)

  # Start handling for DataFrame scalars, first look at the columns
  incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
@@ -986,7 +990,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
  return converted_cols

  def _get_dataset_type(self, t: typing.Union[Type[DataFrame], typing.Any]) -> types_pb2.StructuredDatasetType:
- original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore
+ _original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore

  # Get the column information
  converted_cols: typing.List[types_pb2.StructuredDatasetType.DatasetColumn] = (
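The dominant change across these hunks (and most of the import-only hunks below) is that the generated protobuf modules now come from the flyteidl2 distribution rather than flyteidl. A minimal, hedged sketch of how downstream code that imported these modules directly could absorb the switch; the try/except fallback is purely illustrative, not something the package itself does:

    # flyte 2.0.0b25 imports flyteidl2 directly, as the hunks above show.
    # This fallback is only a sketch for third-party code straddling both layouts.
    try:
        from flyteidl2.core import literals_pb2, types_pb2
    except ImportError:  # environments that still ship the older flyteidl package
        from flyteidl.core import literals_pb2, types_pb2

    # The message API is expected to be the same either way (assumption for flyteidl2).
    int_type = types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER)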
flyte/io/_dir.py CHANGED
@@ -4,12 +4,14 @@ import os
  from pathlib import Path
  from typing import AsyncIterator, Dict, Generic, Iterator, List, Optional, Type, TypeVar, Union

- from flyteidl.core import literals_pb2, types_pb2
+ from flyteidl2.core import literals_pb2, types_pb2
  from fsspec.asyn import AsyncFileSystem
+ from fsspec.utils import get_protocol
  from mashumaro.types import SerializableType
  from pydantic import BaseModel, model_validator

  import flyte.storage as storage
+ from flyte._context import internal_ctx
  from flyte.io._file import File
  from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError

@@ -462,7 +464,15 @@ class Dir(BaseModel, Generic[T], SerializableType):
  Returns:
  The absolute path to the downloaded directory
  """
- local_dest = str(local_path) if local_path else str(storage.get_random_local_path())
+ # If no local_path specified, create a unique path + append source directory name
+ if local_path is None:
+ unique_path = storage.get_random_local_path()
+ source_dirname = Path(self.path).name # will need to be updated for windows
+ local_dest = str(Path(unique_path) / source_dirname)
+ else:
+ # If local_path is specified, use it directly (contents go into it)
+ local_dest = str(local_path)
+
  if not storage.is_remote(self.path):
  if not local_path or local_path == self.path:
  # Skip copying
@@ -477,6 +487,7 @@ class Dir(BaseModel, Generic[T], SerializableType):
  await loop.run_in_executor(None, lambda: shutil.copytree(self.path, local_dest, dirs_exist_ok=True))

  await copy_tree()
+ return local_dest
  return await storage.get(self.path, local_dest, recursive=True)

  def download_sync(self, local_path: Optional[Union[str, Path]] = None) -> str:
@@ -510,7 +521,15 @@ class Dir(BaseModel, Generic[T], SerializableType):
  Returns:
  The absolute path to the downloaded directory
  """
- local_dest = str(local_path) if local_path else str(storage.get_random_local_path())
+ # If no local_path specified, create a unique path + append source directory name
+ if local_path is None:
+ unique_path = storage.get_random_local_path()
+ source_dirname = Path(self.path).name
+ local_dest = str(Path(unique_path) / source_dirname)
+ else:
+ # If local_path is specified, use it directly (contents go into it)
+ local_dest = str(local_path)
+
  if not storage.is_remote(self.path):
  if not local_path or local_path == self.path:
  # Skip copying
@@ -520,6 +539,7 @@ class Dir(BaseModel, Generic[T], SerializableType):
  import shutil

  shutil.copytree(self.path, local_dest, dirs_exist_ok=True)
+ return local_dest

  fs = storage.get_underlying_filesystem(path=self.path)
  fs.get(self.path, local_dest, recursive=True)
@@ -529,7 +549,7 @@
  async def from_local(
  cls,
  local_path: Union[str, Path],
- remote_path: Optional[str] = None,
+ remote_destination: Optional[str] = None,
  dir_cache_key: Optional[str] = None,
  ) -> Dir[T]:
  """
@@ -571,7 +591,8 @@ class Dir(BaseModel, Generic[T], SerializableType):
  ```
  Args:
  local_path: Path to the local directory
- remote_path: Optional remote path to store the directory. If None, a path will be automatically generated.
+ remote_destination: Optional remote path to store the directory. If None, a path will be automatically
+ generated.
  dir_cache_key: Optional precomputed hash value to use for cache key computation when this Dir is used
  as an input to discoverable tasks. If not specified, the cache key will be based on
  directory attributes.
@@ -581,15 +602,23 @@ class Dir(BaseModel, Generic[T], SerializableType):
  """
  local_path_str = str(local_path)
  dirname = os.path.basename(os.path.normpath(local_path_str))
+ resolved_remote_path = remote_destination or internal_ctx().raw_data.get_random_remote_path(dirname)
+ protocol = get_protocol(resolved_remote_path)
+
+ # Shortcut for local, don't copy and just return
+ if "file" in protocol and remote_destination is None:
+ output_path = str(Path(local_path).absolute())
+ return cls(path=output_path, name=dirname, hash=dir_cache_key)

- output_path = await storage.put(from_path=local_path_str, to_path=remote_path, recursive=True)
+ # todo: in the future, mirror File and set the file to_path here
+ output_path = await storage.put(from_path=local_path_str, to_path=remote_destination, recursive=True)
  return cls(path=output_path, name=dirname, hash=dir_cache_key)

  @classmethod
  def from_local_sync(
  cls,
  local_path: Union[str, Path],
- remote_path: Optional[str] = None,
+ remote_destination: Optional[str] = None,
  dir_cache_key: Optional[str] = None,
  ) -> Dir[T]:
  """
@@ -632,7 +661,8 @@ class Dir(BaseModel, Generic[T], SerializableType):

  Args:
  local_path: Path to the local directory
- remote_path: Optional remote path to store the directory. If None, a path will be automatically generated.
+ remote_destination: Optional remote path to store the directory. If None, a path will be automatically
+ generated.
  dir_cache_key: Optional precomputed hash value to use for cache key computation when this Dir is used
  as an input to discoverable tasks. If not specified, the cache key will be based on
  directory attributes.
@@ -643,14 +673,17 @@ class Dir(BaseModel, Generic[T], SerializableType):
  local_path_str = str(local_path)
  dirname = os.path.basename(os.path.normpath(local_path_str))

- if not remote_path:
- from flyte._context import internal_ctx
+ resolved_remote_path = remote_destination or internal_ctx().raw_data.get_random_remote_path(dirname)
+ protocol = get_protocol(resolved_remote_path)
+
+ # Shortcut for local, don't copy and just return
+ if "file" in protocol and remote_destination is None:
+ output_path = str(Path(local_path).absolute())
+ return cls(path=output_path, name=dirname, hash=dir_cache_key)

- ctx = internal_ctx()
- remote_path = ctx.raw_data.get_random_remote_path(dirname)
- fs = storage.get_underlying_filesystem(path=remote_path)
- fs.put(local_path_str, remote_path, recursive=True)
- return cls(path=remote_path, name=dirname, hash=dir_cache_key)
+ fs = storage.get_underlying_filesystem(path=resolved_remote_path)
+ fs.put(local_path_str, resolved_remote_path, recursive=True)
+ return cls(path=resolved_remote_path, name=dirname, hash=dir_cache_key)

  @classmethod
  def from_existing_remote(cls, remote_path: str, dir_cache_key: Optional[str] = None) -> Dir[T]:
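Net effect of the Dir changes: download/download_sync now return the resolved local destination for local sources, from_local/from_local_sync rename remote_path to remote_destination, and a purely local source with no destination is returned in place instead of being copied. A rough usage sketch based only on the signatures above, assuming Dir is importable from flyte.io; the paths and bucket name are illustrative:

    import flyte
    from flyte.io import Dir

    flyte.init()  # assumed: gives the SDK a raw-data path / context

    # Purely local source and no remote_destination: the directory is returned in place, no copy.
    d = Dir.from_local_sync("./data")

    # The keyword is now remote_destination (formerly remote_path); the bucket is illustrative.
    d2 = Dir.from_local_sync("./data", remote_destination="s3://my-bucket/datasets/run-1/")

    # With no argument, download_sync lands in <random-local-path>/<source-dir-name>.
    local_copy = d2.download_sync()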
flyte/io/_file.py CHANGED
@@ -20,7 +20,7 @@ from typing import (
  )

  import aiofiles
- from flyteidl.core import literals_pb2, types_pb2
+ from flyteidl2.core import literals_pb2, types_pb2
  from fsspec.utils import get_protocol
  from mashumaro.types import SerializableType
  from pydantic import BaseModel, Field, model_validator
@@ -33,6 +33,9 @@ from flyte._initialize import requires_initialization
  from flyte.io._hashing_io import AsyncHashingReader, HashingWriter, HashMethod, PrecomputedValue
  from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError

+ if typing.TYPE_CHECKING:
+ from obstore import AsyncReadableFile, AsyncWritableFile
+
  if typing.TYPE_CHECKING:
  from obstore import AsyncReadableFile, AsyncWritableFile

@@ -524,23 +527,40 @@ class File(BaseModel, Generic[T], SerializableType):
  The absolute path to the downloaded file
  """
  if local_path is None:
- local_path = storage.get_random_local_path(file_path_or_file_name=local_path)
+ local_path = storage.get_random_local_path(file_path_or_file_name=self.path)
  else:
+ # Preserve trailing separator if present (Path.absolute() strips it)
+ local_path_str = str(local_path)
+ has_trailing_sep = local_path_str.endswith(os.sep)
  local_path = str(Path(local_path).absolute())
+ if has_trailing_sep:
+ local_path = local_path + os.sep

  fs = storage.get_underlying_filesystem(path=self.path)

  # If it's already a local file, just copy it
  if "file" in fs.protocol:
+ # Apply directory logic for local-to-local copies
+ local_path_for_copy = local_path
+ if isinstance(local_path, str):
+ local_path_obj = Path(local_path)
+ # Check if it's a directory or ends with separator
+ if local_path.endswith(os.sep) or (local_path_obj.exists() and local_path_obj.is_dir()):
+ remote_filename = Path(self.path).name
+ local_path_for_copy = str(local_path_obj / remote_filename)
+
+ # Ensure parent directory exists
+ Path(local_path_for_copy).parent.mkdir(parents=True, exist_ok=True)
+
  # Use aiofiles for async copy
  async with aiofiles.open(self.path, "rb") as src:
- async with aiofiles.open(local_path, "wb") as dst:
+ async with aiofiles.open(local_path_for_copy, "wb") as dst:
  await dst.write(await src.read())
- return str(local_path)
+ return str(local_path_for_copy)

  # Otherwise download from remote using async functionality
- await storage.get(self.path, str(local_path))
- return str(local_path)
+ result_path = await storage.get(self.path, str(local_path))
+ return result_path

  def download_sync(self, local_path: Optional[Union[str, Path]] = None) -> str:
  """
@@ -576,19 +596,36 @@
  The absolute path to the downloaded file
  """
  if local_path is None:
- local_path = storage.get_random_local_path(file_path_or_file_name=local_path)
+ local_path = storage.get_random_local_path(file_path_or_file_name=self.path)
  else:
+ # Preserve trailing separator if present (Path.absolute() strips it)
+ local_path_str = str(local_path)
+ has_trailing_sep = local_path_str.endswith(os.sep)
  local_path = str(Path(local_path).absolute())
+ if has_trailing_sep:
+ local_path = local_path + os.sep

  fs = storage.get_underlying_filesystem(path=self.path)

  # If it's already a local file, just copy it
  if "file" in fs.protocol:
+ # Apply directory logic for local-to-local copies
+ local_path_for_copy = local_path
+ if isinstance(local_path, str):
+ local_path_obj = Path(local_path)
+ # Check if it's a directory or ends with separator
+ if local_path.endswith(os.sep) or (local_path_obj.exists() and local_path_obj.is_dir()):
+ remote_filename = Path(self.path).name
+ local_path_for_copy = str(local_path_obj / remote_filename)
+
+ # Ensure parent directory exists
+ Path(local_path_for_copy).parent.mkdir(parents=True, exist_ok=True)
+
  # Use standard file operations for sync copy
  import shutil

- shutil.copy2(self.path, local_path)
- return str(local_path)
+ shutil.copy2(self.path, local_path_for_copy)
+ return str(local_path_for_copy)

  # Otherwise download from remote using sync functionality
  # Use the sync version of storage operations
@@ -756,9 +793,9 @@ class File(BaseModel, Generic[T], SerializableType):
  if not os.path.exists(local_path):
  raise ValueError(f"File not found: {local_path}")

- remote_path = remote_destination or internal_ctx().raw_data.get_random_remote_path()
- protocol = get_protocol(remote_path)
  filename = Path(local_path).name
+ remote_path = remote_destination or internal_ctx().raw_data.get_random_remote_path(filename)
+ protocol = get_protocol(remote_path)

  # If remote_destination was not set by the user, and the configured raw data path is also local,
  # then let's optimize by not uploading.
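The download hunks above fix the no-argument case (the random local name is now derived from self.path rather than from the still-None local_path) and make an existing directory, or a path ending in a separator, receive the file under its original name. A hedged sketch of the resulting call patterns; the remote URI and local paths are illustrative, and File.from_existing_remote is assumed to be the usual way to wrap an already-uploaded file:

    from pathlib import Path

    from flyte.io import File

    f = File.from_existing_remote("s3://my-bucket/reports/report.csv")  # assumed API, illustrative URI

    Path("downloads").mkdir(exist_ok=True)

    # Existing directory (or trailing os.sep): the source file name is appended,
    # so this is expected to land at downloads/report.csv.
    p1 = f.download_sync("downloads/")

    # Explicit file path: downloaded to exactly this location.
    p2 = f.download_sync("downloads/copy.csv")

    # No argument: a random local path derived from the source file name.
    p3 = f.download_sync()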
flyte/models.py CHANGED
@@ -14,7 +14,7 @@ from flyte._interface import extract_return_annotation, literal_to_enum
  from flyte._logging import logger

  if TYPE_CHECKING:
- from flyteidl.core import literals_pb2
+ from flyteidl2.core import literals_pb2

  from flyte._internal.imagebuild.image_builder import ImageCache
  from flyte.report import Report
@@ -144,7 +144,7 @@ class RawDataPath:

  def get_random_remote_path(self, file_name: Optional[str] = None) -> str:
  """
- Returns a random path for uploading a file/directory to.
+ Returns a random path for uploading a file/directory to. This file/folder will not be created, it's just a path.

  :param file_name: If given, will be joined after a randomly generated portion.
  :return:
@@ -160,13 +160,14 @@ class RawDataPath:

  protocol = get_protocol(file_prefix)
  if "file" in protocol:
- local_path = pathlib.Path(file_prefix) / random_string
+ parent_folder = pathlib.Path(file_prefix)
+ parent_folder.mkdir(exist_ok=True, parents=True)
  if file_name:
- # Only if file name is given do we create the parent, because it may be needed as a folder otherwise
- local_path = local_path / file_name
- if not local_path.exists():
- local_path.parent.mkdir(exist_ok=True, parents=True)
- local_path.touch()
+ random_folder = parent_folder / random_string
+ random_folder.mkdir()
+ local_path = random_folder / file_name
+ else:
+ local_path = parent_folder / random_string
  return str(local_path.absolute())

  fs = fsspec.filesystem(protocol)
@@ -194,6 +195,8 @@
  :param action: The action ID of the current execution. This is always set, within a run.
  :param version: The version of the executed task. This is set when the task is executed by an action and will be
  set on all sub-actions.
+ :param custom_context: Context metadata for the action. If an action receives context, it'll automatically pass it
+ to any actions it spawns. Context will not be used for cache key computation.
  """

  action: ActionID
@@ -210,6 +213,7 @@
  data: Dict[str, Any] = field(default_factory=dict)
  mode: Literal["local", "remote", "hybrid"] = "remote"
  interactive_mode: bool = False
+ custom_context: Dict[str, str] = field(default_factory=dict)

  def replace(self, **kwargs) -> TaskContext:
  if "data" in kwargs:
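For a local file:// raw-data prefix, get_random_remote_path now always creates the prefix folder, creates the random subfolder only when a file name is supplied, and never touches the file itself. A rough illustration, under the assumption that RawDataPath is constructed with its prefix as a path field (as the surrounding module suggests); the /tmp prefix and file name are made up:

    from flyte.models import RawDataPath

    rdp = RawDataPath(path="/tmp/flyte-raw")     # hypothetical local prefix

    p1 = rdp.get_random_remote_path()            # /tmp/flyte-raw/<random>; prefix created, random folder not
    p2 = rdp.get_random_remote_path("data.csv")  # /tmp/flyte-raw/<random>/data.csv; folder created, file not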
flyte/remote/_action.py CHANGED
@@ -20,14 +20,15 @@ from typing import (
  import grpc
  import rich.pretty
  import rich.repr
+ from flyteidl2.common import identifier_pb2, list_pb2
+ from flyteidl2.task import common_pb2
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
+ from flyteidl2.workflow.run_service_pb2 import WatchActionDetailsResponse
  from rich.console import Console
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn

  from flyte import types
- from flyte._initialize import ensure_client, get_client, get_common_config
- from flyte._protos.common import identifier_pb2, list_pb2
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
- from flyte._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
+ from flyte._initialize import ensure_client, get_client, get_init_config
  from flyte.remote._common import ToJSONMixin
  from flyte.remote._logs import Logs
  from flyte.syncify import syncify
@@ -67,14 +68,14 @@ def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
  """
  Rich representation of the action.
  """
- yield "run", action.id.run.name
+ yield "name", action.id.run.name
  if action.metadata.HasField("task"):
- yield "task", action.metadata.task.id.name
+ yield "task name", action.metadata.task.id.name
  yield "type", action.metadata.task.task_type
  elif action.metadata.HasField("trace"):
  yield "trace", action.metadata.trace.name
  yield "type", "trace"
- yield "name", action.id.name
+ yield "action name", action.id.name
  yield from _action_time_phase(action)
  yield "group", action.metadata.group
  yield "parent", action.metadata.parent
@@ -98,9 +99,10 @@ def _action_details_rich_repr(
  """
  yield "name", action.id.run.name
  yield from _action_time_phase(action)
- yield "task", action.resolved_task_spec.task_template.id.name
- yield "task_type", action.resolved_task_spec.task_template.type
- yield "task_version", action.resolved_task_spec.task_template.id.version
+ if action.HasField("task"):
+ yield "task", action.task.task_template.id.name
+ yield "task_type", action.task.task_template.type
+ yield "task_version", action.task.task_template.id.version
  yield "attempts", action.attempts
  yield "error", (f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA")
  yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
@@ -152,7 +154,7 @@ class Action(ToJSONMixin):
  key=sort_by[0],
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
  )
- cfg = get_common_config()
+ cfg = get_init_config()
  while True:
  req = list_pb2.ListRequest(
  limit=100,
@@ -193,7 +195,7 @@ class Action(ToJSONMixin):
  :param name: The name of the action.
  """
  ensure_client()
- cfg = get_common_config()
+ cfg = get_init_config()
  details: ActionDetails = await ActionDetails.get_details.aio(
  identifier_pb2.ActionIdentifier(
  run=identifier_pb2.RunIdentifier(
@@ -458,7 +460,7 @@ class ActionDetails(ToJSONMixin):
  ensure_client()
  if not uri:
  assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
- cfg = get_common_config()
+ cfg = get_init_config()
  return await cls.get_details.aio(
  identifier_pb2.ActionIdentifier(
  run=identifier_pb2.RunIdentifier(
@@ -706,7 +708,7 @@ class ActionInputs(UserDict, ToJSONMixin):
  remote Union API.
  """

- pb2: run_definition_pb2.Inputs
+ pb2: common_pb2.Inputs
  data: Dict[str, Any]

  def __repr__(self):
@@ -723,14 +725,14 @@ class ActionOutputs(tuple, ToJSONMixin):
  remote Union API.
  """

- def __new__(cls, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
+ def __new__(cls, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
  # Create the tuple part
  obj = super().__new__(cls, data)
  # Store extra data (you can't do this here directly since it's immutable)
  obj.pb2 = pb2
  return obj

- def __init__(self, pb2: run_definition_pb2.Outputs, data: Tuple[Any, ...]):
+ def __init__(self, pb2: common_pb2.Outputs, data: Tuple[Any, ...]):
  # Normally you'd set instance attributes here,
  # but we've already set `pb2` in `__new__`
  self.pb2 = pb2
@@ -2,12 +2,13 @@ from typing import AsyncIterator, Protocol

  from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
  from flyteidl.service import dataproxy_pb2, identity_pb2
+ from flyteidl2.secret import payload_pb2
+ from flyteidl2.task import task_service_pb2
+ from flyteidl2.trigger import trigger_service_pb2
+ from flyteidl2.workflow import run_logs_service_pb2, run_service_pb2
  from grpc.aio import UnaryStreamCall
  from grpc.aio._typing import RequestType

- from flyte._protos.secret import payload_pb2
- from flyte._protos.workflow import run_logs_service_pb2, run_service_pb2, task_service_pb2, trigger_service_pb2
-

  class MetadataServiceProtocol(Protocol):
  async def GetVersion(self, request: version_pb2.GetVersionRequest) -> version_pb2.GetVersionResponse: ...
@@ -112,7 +112,7 @@ async def create_channel(
  if api_key:
  from flyte.remote._client.auth._auth_utils import decode_api_key

- endpoint, client_id, client_secret, org = decode_api_key(api_key)
+ endpoint, client_id, client_secret, _org = decode_api_key(api_key)
  kwargs["auth_type"] = "ClientSecret"
  kwargs["client_id"] = client_id
  kwargs["client_secret"] = client_secret
@@ -16,14 +16,10 @@ if "GRPC_VERBOSITY" not in os.environ:

  import grpc
  from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
-
- from flyte._protos.secret import secret_pb2_grpc
- from flyte._protos.workflow import (
- run_logs_service_pb2_grpc,
- run_service_pb2_grpc,
- task_service_pb2_grpc,
- trigger_service_pb2_grpc,
- )
+ from flyteidl2.secret import secret_pb2_grpc
+ from flyteidl2.task import task_service_pb2_grpc
+ from flyteidl2.trigger import trigger_service_pb2_grpc
+ from flyteidl2.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc

  from ._protocols import (
  DataProxyService,
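The hunks above, together with the wholesale removal of the flyte/_protos/** modules in the file list, mean the vendored protobuf stubs are gone and the service stubs now come from flyteidl2. A hedged sketch of the corresponding import change for any code that reached into the vendored package (the old path is taken from the removed lines above):

    # Before (2.0.0b23): vendored stubs shipped inside the wheel.
    # from flyte._protos.workflow import run_definition_pb2, run_service_pb2

    # After (2.0.0b25): the same modules come from the flyteidl2 distribution.
    from flyteidl2.workflow import run_definition_pb2, run_service_pb2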
flyte/remote/_data.py CHANGED
@@ -15,7 +15,7 @@ import httpx
  from flyteidl.service import dataproxy_pb2
  from google.protobuf import duration_pb2

- from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
+ from flyte._initialize import CommonInit, ensure_client, get_client, get_init_config, require_project_and_domain
  from flyte.errors import InitializationError, RuntimeSystemError
  from flyte.syncify import syncify

@@ -54,6 +54,7 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> Tuple[bytes, str, in
  return h.digest(), h.hexdigest(), size


+ @require_project_and_domain
  async def _upload_single_file(
  cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
  ) -> Tuple[str, str]:
@@ -124,7 +125,7 @@ async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
  """
  # This is a placeholder implementation. Replace with actual upload logic.
  ensure_client()
- cfg = get_common_config()
+ cfg = get_init_config()
  if not fp.is_file():
  raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
  return await _upload_single_file(cfg, fp, verify=verify)
@@ -140,7 +141,7 @@ async def upload_dir(dir_path: Path, verify: bool = True) -> str:
  """
  # This is a placeholder implementation. Replace with actual upload logic.
  ensure_client()
- cfg = get_common_config()
+ cfg = get_init_config()
  if not dir_path.is_dir():
  raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")

flyte/remote/_logs.py CHANGED
@@ -4,6 +4,9 @@ from dataclasses import dataclass
  from typing import AsyncGenerator, AsyncIterator

  import grpc
+ from flyteidl2.common import identifier_pb2
+ from flyteidl2.logs.dataplane import payload_pb2
+ from flyteidl2.workflow import run_logs_service_pb2
  from rich.console import Console
  from rich.live import Live
  from rich.panel import Panel
@@ -11,9 +14,6 @@ from rich.text import Text

  from flyte._initialize import ensure_client, get_client
  from flyte._logging import logger
- from flyte._protos.common import identifier_pb2
- from flyte._protos.logs.dataplane import payload_pb2
- from flyte._protos.workflow import run_logs_service_pb2
  from flyte._tools import ipython_check, ipywidgets_check
  from flyte.errors import LogsNotYetAvailableError
  from flyte.syncify import syncify
flyte/remote/_run.py CHANGED
@@ -5,11 +5,11 @@ from typing import AsyncGenerator, AsyncIterator, Literal, Tuple

  import grpc
  import rich.repr
+ from flyteidl2.common import identifier_pb2, list_pb2
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2

- from flyte._initialize import ensure_client, get_client, get_common_config
+ from flyte._initialize import ensure_client, get_client, get_init_config
  from flyte._logging import logger
- from flyte._protos.common import identifier_pb2, list_pb2
- from flyte._protos.workflow import run_definition_pb2, run_service_pb2
  from flyte.syncify import syncify

  from . import Action, ActionDetails, ActionInputs, ActionOutputs
@@ -97,7 +97,7 @@ class Run(ToJSONMixin):
  ),
  )

- cfg = get_common_config()
+ cfg = get_init_config()
  i = 0
  while True:
  req = list_pb2.ListRequest(
@@ -316,7 +316,7 @@ class RunDetails(ToJSONMixin):
  :param name: The name of the run.
  """
  ensure_client()
- cfg = get_common_config()
+ cfg = get_init_config()
  return await RunDetails.get_details.aio(
  run_id=identifier_pb2.RunIdentifier(
  org=cfg.org,