flyte 2.0.0b17__py3-none-any.whl → 2.0.0b19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte might be problematic.

Files changed (52)
  1. flyte/_bin/runtime.py +3 -0
  2. flyte/_debug/vscode.py +4 -2
  3. flyte/_deploy.py +3 -1
  4. flyte/_environment.py +15 -6
  5. flyte/_hash.py +1 -16
  6. flyte/_image.py +6 -1
  7. flyte/_initialize.py +15 -16
  8. flyte/_internal/controllers/__init__.py +4 -5
  9. flyte/_internal/controllers/_local_controller.py +5 -5
  10. flyte/_internal/controllers/remote/_controller.py +21 -28
  11. flyte/_internal/controllers/remote/_core.py +1 -1
  12. flyte/_internal/imagebuild/docker_builder.py +31 -23
  13. flyte/_internal/imagebuild/remote_builder.py +37 -10
  14. flyte/_internal/imagebuild/utils.py +2 -1
  15. flyte/_internal/runtime/convert.py +69 -2
  16. flyte/_internal/runtime/taskrunner.py +4 -1
  17. flyte/_logging.py +110 -26
  18. flyte/_map.py +90 -12
  19. flyte/_pod.py +2 -1
  20. flyte/_run.py +6 -1
  21. flyte/_task.py +3 -0
  22. flyte/_task_environment.py +5 -1
  23. flyte/_trace.py +5 -0
  24. flyte/_version.py +3 -3
  25. flyte/cli/_create.py +4 -1
  26. flyte/cli/_deploy.py +4 -5
  27. flyte/cli/_params.py +18 -4
  28. flyte/cli/_run.py +2 -2
  29. flyte/config/_config.py +2 -2
  30. flyte/config/_reader.py +14 -8
  31. flyte/errors.py +3 -1
  32. flyte/git/__init__.py +3 -0
  33. flyte/git/_config.py +17 -0
  34. flyte/io/_dataframe/basic_dfs.py +16 -7
  35. flyte/io/_dataframe/dataframe.py +84 -123
  36. flyte/io/_dir.py +35 -4
  37. flyte/io/_file.py +61 -15
  38. flyte/io/_hashing_io.py +342 -0
  39. flyte/models.py +12 -4
  40. flyte/remote/_action.py +4 -2
  41. flyte/remote/_task.py +52 -22
  42. flyte/report/_report.py +1 -1
  43. flyte/storage/_storage.py +16 -1
  44. flyte/types/_type_engine.py +1 -51
  45. {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/runtime.py +3 -0
  46. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/METADATA +1 -1
  47. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/RECORD +52 -49
  48. {flyte-2.0.0b17.data → flyte-2.0.0b19.data}/scripts/debug.py +0 -0
  49. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/WHEEL +0 -0
  50. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/entry_points.txt +0 -0
  51. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/licenses/LICENSE +0 -0
  52. {flyte-2.0.0b17.dist-info → flyte-2.0.0b19.dist-info}/top_level.txt +0 -0
flyte/config/_config.py CHANGED
@@ -192,7 +192,7 @@ class Config(object):
     )
 
     @classmethod
-    def auto(cls, config_file: typing.Union[str, ConfigFile, None] = None) -> "Config":
+    def auto(cls, config_file: typing.Union[str, pathlib.Path, ConfigFile, None] = None) -> "Config":
         """
         Automatically constructs the Config Object. The order of precedence is as follows
         1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format.
@@ -225,7 +225,7 @@ def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
     return d
 
 
-def auto(config_file: typing.Union[str, ConfigFile, None] = None) -> Config:
+def auto(config_file: typing.Union[str, pathlib.Path, ConfigFile, None] = None) -> Config:
    """
    Automatically constructs the Config Object. The order of precedence is as follows
    1. If specified, read the config from the provided file path.
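Callers can now hand `auto` a `pathlib.Path` directly instead of stringifying it first. A minimal usage sketch, assuming `flyte.config.auto` is the public entry point (the new `flyte/git/_config.py` below relies on the same call) and that the paths shown actually exist:

```python
import pathlib

import flyte.config

# Previously only str or ConfigFile were accepted; Path objects now work too.
cfg = flyte.config.auto("config.yaml")
cfg = flyte.config.auto(pathlib.Path.home() / ".flyte" / "config.yaml")
```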
flyte/config/_reader.py CHANGED
@@ -108,7 +108,7 @@ class ConfigFile(object):
         return pathlib.Path(self._location)
 
     @staticmethod
-    def _read_yaml_config(location: str) -> typing.Optional[typing.Dict[str, typing.Any]]:
+    def _read_yaml_config(location: str | pathlib.Path) -> typing.Optional[typing.Dict[str, typing.Any]]:
         with open(location, "r") as fh:
             try:
                 yaml_contents = yaml.safe_load(fh)
@@ -139,16 +139,22 @@ def resolve_config_path() -> pathlib.Path | None:
     """
     Config is read from the following locations in order of precedence:
     1. ./config.yaml if it exists
-    2. `UCTL_CONFIG` environment variable
-    3. `FLYTECTL_CONFIG` environment variable
-    4. ~/.union/config.yaml if it exists
-    5. ~/.flyte/config.yaml if it exists
+    2. ./.flyte/config.yaml if it exists
+    3. `UCTL_CONFIG` environment variable
+    4. `FLYTECTL_CONFIG` environment variable
+    5. ~/.union/config.yaml if it exists
+    6. ~/.flyte/config.yaml if it exists
     """
     current_location_config = Path("config.yaml")
     if current_location_config.exists():
         return current_location_config
     logger.debug("No ./config.yaml found")
 
+    dot_flyte_config = Path(".flyte", "config.yaml")
+    if dot_flyte_config.exists():
+        return dot_flyte_config
+    logger.debug("No ./.flyte/config.yaml found")
+
     uctl_path_from_env = getenv(UCTL_CONFIG_ENV_VAR, None)
     if uctl_path_from_env:
         return pathlib.Path(uctl_path_from_env)
@@ -173,13 +179,13 @@ def resolve_config_path() -> pathlib.Path | None:
 
 
 @lru_cache
-def get_config_file(c: typing.Union[str, ConfigFile, None]) -> ConfigFile | None:
+def get_config_file(c: typing.Union[str, pathlib.Path, ConfigFile, None]) -> ConfigFile | None:
     """
     Checks if the given argument is a file or a configFile and returns a loaded configFile else returns None
     """
-    if isinstance(c, str):
+    if isinstance(c, (str, pathlib.Path)):
         logger.debug(f"Using specified config file at {c}")
-        return ConfigFile(c)
+        return ConfigFile(str(c))
     elif isinstance(c, ConfigFile):
         return c
     config_path = resolve_config_path()
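In practice the new entry means a repo-local `.flyte/config.yaml` now wins over the environment variables and home-directory configs. A small sketch of the resolution helpers shown above, assuming they remain importable from `flyte.config._reader` (an internal module) and that the referenced files exist:

```python
import pathlib

from flyte.config._reader import get_config_file, resolve_config_path

# With no argument, resolution tries ./config.yaml, then ./.flyte/config.yaml,
# then the UCTL_CONFIG / FLYTECTL_CONFIG env vars, then the home-directory files.
print(resolve_config_path())

# Path objects are now accepted and converted to str before building the ConfigFile.
cf = get_config_file(pathlib.Path(".flyte/config.yaml"))
```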
flyte/errors.py CHANGED
@@ -132,7 +132,9 @@ class CustomError(RuntimeUserError):
         Create a CustomError from an exception. The exception's class name is used as the error code and the exception
         message is used as the error message.
         """
-        return cls(e.__class__.__name__, str(e))
+        new_exc = cls(e.__class__.__name__, str(e))
+        new_exc.__cause__ = e
+        return new_exc
 
 
 class NotInTaskContextError(RuntimeUserError):
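Attaching `__cause__` preserves the original exception and its traceback when the wrapped error propagates. A plain-Python illustration of the mechanism (not flyte-specific; `WrappedError` and its `from_exception` name are stand-ins for the real class):

```python
class WrappedError(RuntimeError):
    """Stand-in for CustomError: wraps another exception's name and message."""

    def __init__(self, code: str, message: str):
        super().__init__(message)
        self.code = code

    @classmethod
    def from_exception(cls, e: Exception) -> "WrappedError":
        new_exc = cls(e.__class__.__name__, str(e))
        new_exc.__cause__ = e  # keep the original exception chained
        return new_exc


try:
    try:
        {}["missing"]
    except KeyError as e:
        raise WrappedError.from_exception(e)
except WrappedError as wrapped:
    # Tracebacks now print "The above exception was the direct cause of the
    # following exception", and the original error stays reachable here:
    assert isinstance(wrapped.__cause__, KeyError)
```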
flyte/git/__init__.py ADDED
@@ -0,0 +1,3 @@
+from ._config import config_from_root
+
+__all__ = ["config_from_root"]
flyte/git/_config.py ADDED
@@ -0,0 +1,17 @@
+import pathlib
+import subprocess
+
+import flyte.config
+
+
+def config_from_root(path: pathlib.Path | str = ".flyte/config.yaml") -> flyte.config.Config:
+    """Get the config file from the git root directory.
+
+    By default, the config file is expected to be in `.flyte/config.yaml` in the git root directory.
+    """
+
+    result = subprocess.run(["git", "rev-parse", "--show-toplevel"], check=False, capture_output=True, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Failed to get git root directory: {result.stderr}")
+    root = pathlib.Path(result.stdout.strip())
+    return flyte.config.auto(root / path)
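A short usage sketch for the new helper (assumes the code runs inside a git checkout; the alternate path is illustrative):

```python
import flyte.git

# Load the repo-level config no matter which subdirectory the process starts in.
cfg = flyte.git.config_from_root()  # defaults to <git root>/.flyte/config.yaml

# Any other path relative to the git root works as well.
cfg = flyte.git.config_from_root("configs/dev.yaml")
```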
flyte/io/_dataframe/basic_dfs.py CHANGED
@@ -58,16 +58,16 @@ class PandasToCSVEncodingHandler(DataFrameEncoder):
 
         if not storage.is_remote(uri):
             Path(uri).mkdir(parents=True, exist_ok=True)
-        path = os.path.join(uri, ".csv")
+        csv_file = storage.join(uri, "data.csv")
         df = typing.cast(pd.DataFrame, dataframe.val)
         df.to_csv(
-            path,
+            csv_file,
             index=False,
-            storage_options=get_pandas_storage_options(uri=path),
+            storage_options=get_pandas_storage_options(uri=csv_file),
         )
         structured_dataset_type.format = CSV
         return literals_pb2.StructuredDataset(
-            uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type)
+            uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type)
         )
 
 
@@ -83,16 +83,25 @@ class CSVToPandasDecodingHandler(DataFrameDecoder):
         uri = proto_value.uri
         columns = None
         kwargs = get_pandas_storage_options(uri=uri)
-        path = os.path.join(uri, ".csv")
+        csv_file = storage.join(uri, "data.csv")
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
         try:
-            return pd.read_csv(path, usecols=columns, storage_options=kwargs)
+            import io
+
+            # The pattern used here is a bit wonky because of obstore issues with csv, getting early eof error.
+            buf = io.BytesIO()
+            async for chunk in storage.get_stream(csv_file):
+                buf.write(chunk)
+            buf.seek(0)
+            df = pd.read_csv(buf)
+            return df
+
         except Exception as exc:
             if exc.__class__.__name__ == "NoCredentialsError":
                 logger.debug("S3 source detected, attempting anonymous S3 access")
                 kwargs = get_pandas_storage_options(uri=uri, anonymous=True)
-                return pd.read_csv(path, usecols=columns, storage_options=kwargs)
+                return pd.read_csv(csv_file, usecols=columns, storage_options=kwargs)
             else:
                 raise
 
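The decoder change above buffers the whole object-store stream in memory before handing it to pandas, sidestepping the early-EOF errors seen when pandas read through obstore directly. A self-contained sketch of the same pattern, with the async stream simulated rather than coming from `storage.get_stream`:

```python
import asyncio
import io
from typing import AsyncIterator

import pandas as pd


async def read_csv_from_stream(chunks: AsyncIterator[bytes]) -> pd.DataFrame:
    # Accumulate the streamed bytes first, then parse once from the in-memory buffer.
    buf = io.BytesIO()
    async for chunk in chunks:
        buf.write(chunk)
    buf.seek(0)
    return pd.read_csv(buf)


async def _demo() -> None:
    async def fake_stream() -> AsyncIterator[bytes]:
        yield b"a,b\n"
        yield b"1,2\n"

    df = await read_csv_from_stream(fake_stream())
    assert list(df.columns) == ["a", "b"]


asyncio.run(_demo())
```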
flyte/io/_dataframe/dataframe.py CHANGED
@@ -1,20 +1,17 @@
 from __future__ import annotations
 
 import _datetime
-import asyncio
 import collections
 import types
 import typing
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field, is_dataclass
+from dataclasses import is_dataclass
 from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union
 
-import msgpack
 from flyteidl.core import literals_pb2, types_pb2
 from fsspec.utils import get_protocol
-from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
-from pydantic import model_serializer, model_validator
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
 from typing_extensions import Annotated, TypeAlias, get_args, get_origin
 
 import flyte.storage as storage
@@ -48,15 +45,23 @@ GENERIC_FORMAT: DataFrameFormat = ""
 GENERIC_PROTOCOL: str = "generic protocol"
 
 
-@dataclass
-class DataFrame(SerializableType, DataClassJSONMixin):
+class DataFrame(BaseModel, SerializableType):
     """
     This is the user facing DataFrame class. Please don't confuse it with the literals.StructuredDataset
     class (that is just a model, a Python class representation of the protobuf).
     """
 
-    uri: typing.Optional[str] = field(default=None)
-    file_format: typing.Optional[str] = field(default=GENERIC_FORMAT)
+    uri: typing.Optional[str] = Field(default=None)
+    format: typing.Optional[str] = Field(default=GENERIC_FORMAT)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    # Private attributes that are not part of the Pydantic model schema
+    _raw_df: typing.Optional[typing.Any] = PrivateAttr(default=None)
+    _metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = PrivateAttr(default=None)
+    _literal_sd: Optional[literals_pb2.StructuredDataset] = PrivateAttr(default=None)
+    _dataframe_type: Optional[Type[Any]] = PrivateAttr(default=None)
+    _already_uploaded: bool = PrivateAttr(default=False)
 
     # loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
     def _serialize(self) -> Dict[str, Optional[str]]:
@@ -65,16 +70,16 @@ class DataFrame(SerializableType, DataClassJSONMixin):
         engine = DataFrameTransformerEngine()
         lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
         sd = DataFrame(uri=lv.scalar.structured_dataset.uri)
-        sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+        sd.format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
         return {
             "uri": sd.uri,
-            "file_format": sd.file_format,
+            "format": sd.format,
         }
 
     @classmethod
-    def _deserialize(cls, value) -> "DataFrame":
+    def _deserialize(cls, value) -> DataFrame:
         uri = value.get("uri", None)
-        file_format = value.get("file_format", None)
+        format_val = value.get("format", None)
 
         if uri is None:
             raise ValueError("DataFrame's uri and file format should not be None")
@@ -86,7 +91,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
             scalar=literals_pb2.Scalar(
                 structured_dataset=literals_pb2.StructuredDataset(
                     metadata=literals_pb2.StructuredDatasetMetadata(
-                        structured_dataset_type=types_pb2.StructuredDatasetType(format=file_format)
+                        structured_dataset_type=types_pb2.StructuredDatasetType(format=format_val)
                     ),
                     uri=uri,
                 )
@@ -102,7 +107,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
         lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
         return {
             "uri": lv.scalar.structured_dataset.uri,
-            "file_format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+            "format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
         }
 
     @model_validator(mode="after")
@@ -117,7 +122,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
             scalar=literals_pb2.Scalar(
                 structured_dataset=literals_pb2.StructuredDataset(
                     metadata=literals_pb2.StructuredDatasetMetadata(
-                        structured_dataset_type=types_pb2.StructuredDatasetType(format=self.file_format)
+                        structured_dataset_type=types_pb2.StructuredDatasetType(format=self.format)
                     ),
                     uri=self.uri,
                 )
@@ -134,30 +139,46 @@ class DataFrame(SerializableType, DataClassJSONMixin):
     def column_names(cls) -> typing.List[str]:
         return [k for k, v in cls.columns().items()]
 
-    def __init__(
-        self,
+    @classmethod
+    def from_df(
+        cls,
         val: typing.Optional[typing.Any] = None,
         uri: typing.Optional[str] = None,
-        metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = None,
+    ) -> DataFrame:
+        """
+        Wrapper to create a DataFrame from a dataframe.
+        The reason this is implemented as a wrapper instead of a full translation invoking
+        the type engine and the encoders is because there's too much information in the type
+        signature of the task that we don't want the user to have to replicate.
+        """
+        instance = cls(uri=uri)
+        instance._raw_df = val
+        return instance
+
+    @classmethod
+    def from_existing_remote(
+        cls,
+        remote_path: str,
+        format: typing.Optional[str] = None,
         **kwargs,
-    ):
-        self._val = val
-        # Make these fields public, so that the dataclass transformer can set a value for it
-        # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
-        self.uri = uri
-        # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
-        self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
-        # This is a special attribute that indicates if the data was either downloaded or uploaded
-        self._metadata = metadata
-        # This is not for users to set, the transformer will set this.
-        self._literal_sd: Optional[literals_pb2.StructuredDataset] = None
-        # Not meant for users to set, will be set by an open() call
-        self._dataframe_type: Optional[DF] = None  # type: ignore
-        self._already_uploaded = False
+    ) -> "DataFrame":
+        """
+        Create a DataFrame reference from an existing remote dataframe.
+
+        Args:
+            remote_path: The remote path to the existing dataframe
+            format: Format of the stored dataframe
+
+        Example:
+        ```python
+        df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
+        ```
+        """
+        return cls(uri=remote_path, format=format or GENERIC_FORMAT, **kwargs)
 
     @property
     def val(self) -> Optional[DF]:
-        return self._val
+        return self._raw_df
 
     @property
     def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
@@ -201,7 +222,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
 
         @task
         def return_df() -> DataFrame:
-            df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
+            df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", format="parquet")
             df = df.open(pd.DataFrame).all()
             return df
 
@@ -244,6 +265,9 @@ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
             fields = getattr(value, "__dataclass_fields__")
             d = {k: v.type for k, v in fields.items()}
             result.update(flatten_dict(sub_dict=d, parent_key=current_key))
+        elif hasattr(value, "model_fields"):  # Pydantic model
+            d = {k: v.annotation for k, v in value.model_fields.items()}
+            result.update(flatten_dict(sub_dict=d, parent_key=current_key))
         else:
             result[current_key] = value
     return result
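For reference, `model_fields` is the Pydantic v2 class attribute the new branch leans on: it maps field names to `FieldInfo` objects whose `.annotation` holds the declared type. A tiny standalone check of that extraction (the `Point` model is illustrative):

```python
from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: float


# Mirrors what the new flatten_dict branch extracts for Pydantic models.
annotations = {name: info.annotation for name, info in Point.model_fields.items()}
assert annotations == {"x": int, "y": float}
```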
@@ -708,16 +732,16 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         # return DataFrame(uri=uri)
         if python_val.val is None:
             uri = python_val.uri
-            file_format = python_val.file_format
+            format_val = python_val.format
 
             # Check the user-specified uri
             if not uri:
                 raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
             if not storage.is_remote(uri):
-                uri = await storage.put(uri)
+                uri = await storage.put(uri, recursive=True)
 
-            # Check the user-specified file_format
-            # When users specify file_format for a DataFrame, the file_format should be retained
+            # Check the user-specified format
+            # When users specify format for a DataFrame, the format should be retained
             # conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
             # Following illustrates why we can't always copy the user-specified file_format over:
             #
@@ -725,14 +749,14 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
             # def modify_format(df: Annotated[DataFrame, {}, "task-format"]) -> DataFrame:
             #     return df
             #
-            # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
+            # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", format="user-format")
             # df2 = modify_format(df=df)
             #
-            # In this case, we expect the df2.file_format to be task-format (as shown in Annotated),
-            # not user-format. If we directly copy the user-specified file_format over,
+            # In this case, we expect the df2.format to be task-format (as shown in Annotated),
+            # not user-format. If we directly copy the user-specified format over,
             # the type hint information will be missing.
-            if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
-                sdt.format = file_format
+            if sdt.format == GENERIC_FORMAT and format_val != GENERIC_FORMAT:
+                sdt.format = format_val
 
             sd_model = literals_pb2.StructuredDataset(
                 uri=uri,
@@ -760,8 +784,9 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
                 structured_dataset_type=expected.structured_dataset_type if expected else None
             )
 
-        sd = DataFrame(val=python_val, metadata=meta)
-        return await self.encode(sd, python_type, protocol, fmt, sdt)
+        fdf = DataFrame.from_df(val=python_val)
+        fdf._metadata = meta
+        return await self.encode(fdf, python_type, protocol, fmt, sdt)
 
 
     def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
@@ -782,7 +807,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
 
     async def encode(
         self,
-        sd: DataFrame,
+        df: DataFrame,
         df_type: Type,
         protocol: str,
         format: str,
@@ -791,7 +816,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         handler: DataFrameEncoder
         handler = self.get_encoder(df_type, protocol, format)
 
-        sd_model = await handler.encode(sd, structured_literal_type)
+        sd_model = await handler.encode(df, structured_literal_type)
         # This block is here in case the encoder did not set the type information in the metadata. Since this literal
         # is special in that it carries around the type itself, we want to make sure the type info therein is at
         # least as good as the type of the interface.
@@ -807,72 +832,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
 
         # Because the handler.encode may have uploaded something, and because the sd may end up living inside a
-        # dataclass, we need to modify any uploaded flyte:// urls here.
-        modify_literal_uris(lit)  # todo: verify that this can be removed.
-        sd._literal_sd = sd_model
-        sd._already_uploaded = True
+        # dataclass, we need to modify any uploaded flyte:// urls here. Needed here even though the Type engine
+        # already does this because the DataframeTransformerEngine may be called directly.
+        modify_literal_uris(lit)
+        df._literal_sd = sd_model
+        df._already_uploaded = True
         return lit
 
-    # pr: han-ru: can this be removed if we make DataFrame a pydantic model?
-    def dict_to_dataframe(
-        self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | DataFrame
-    ) -> T | DataFrame:
-        uri = dict_obj.get("uri", None)
-        file_format = dict_obj.get("file_format", None)
-
-        if uri is None:
-            raise ValueError("DataFrame's uri and file format should not be None")
-
-        # Instead of using python native DataFrame, we need to build a literals.StructuredDataset
-        # The reason is that _literal_sd of python sd is accessed when task output LiteralMap is
-        # converted back to flyteidl. Hence, _literal_sd must have to_flyte_idl method
-        # See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
-        # For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
-        sdt = types_pb2.StructuredDatasetType(format=file_format)
-        metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
-        sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
-
-        return asyncio.run(
-            DataFrameTransformerEngine().to_python_value(
-                literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
-                expected_python_type,
-            )
-        )
-
-    def from_binary_idl(
-        self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | DataFrame
-    ) -> T | DataFrame:
-        """
-        If the input is from flytekit, the Life Cycle will be as follows:
-
-        Life Cycle:
-        binary IDL -> resolved binary -> bytes -> expected Python object
-        (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized
-        serialization) deserialization)
-
-        Example Code:
-        @dataclass
-        class DC:
-            sd: StructuredDataset
-
-        @workflow
-        def wf(dc: DC):
-            t_sd(dc.sd)
-
-        Note:
-        - The deserialization is the same as put a structured dataset in a dataclass,
-        which will deserialize by the mashumaro's API.
-
-        Related PR:
-        - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
-        - Link: https://github.com/flyteorg/flytekit/pull/2554
-        """
-        if binary_idl_object.tag == MESSAGEPACK:
-            python_val = msgpack.loads(binary_idl_object.value)
-            return self.dict_to_dataframe(dict_obj=python_val, expected_python_type=expected_python_type)
-        else:
-            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
-
     async def to_python_value(
         self, lv: literals_pb2.Literal, expected_python_type: Type[T] | DataFrame
     ) -> T | DataFrame:
@@ -906,9 +872,8 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
        |                             | the running task's signature.           |                                      |
        +-----------------------------+-----------------------------------------+--------------------------------------+
         """
-        # Handle dataclass attribute access
         if lv.HasField("scalar") and lv.scalar.HasField("binary"):
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            raise TypeTransformerFailedError("Attribute access unsupported.")
 
         # Detect annotations and extract out all the relevant information that the user might supply
         expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
@@ -939,16 +904,12 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         # t1(input_a: DataFrame) # or
         # t1(input_a: Annotated[DataFrame, my_cols])
         if issubclass(expected_python_type, DataFrame):
-            sd = expected_python_type(
-                dataframe=None,
-                # Note here that the type being passed in
-                metadata=metad,
-            )
-            sd._literal_sd = lv.scalar.structured_dataset
-            sd.file_format = metad.structured_dataset_type.format
-            return sd
+            fdf = DataFrame(format=metad.structured_dataset_type.format)
+            fdf._literal_sd = lv.scalar.structured_dataset
+            fdf._metadata = metad
+            return fdf
 
-        # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
+        # If the requested type was not a flyte.DataFrame, then it means it was a raw dataframe type, which means
         # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
         return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
 
flyte/io/_dir.py CHANGED
@@ -48,6 +48,7 @@ class Dir(BaseModel, Generic[T], SerializableType):
     path: str
     name: Optional[str] = None
     format: str = ""
+    hash: Optional[str] = None
 
     class Config:
         arbitrary_types_allowed = True
@@ -248,13 +249,20 @@ class Dir(BaseModel, Generic[T], SerializableType):
             raise NotImplementedError("Sync download is not implemented for remote paths")
 
     @classmethod
-    async def from_local(cls, local_path: Union[str, Path], remote_path: Optional[str] = None) -> Dir[T]:
+    async def from_local(
+        cls,
+        local_path: Union[str, Path],
+        remote_path: Optional[str] = None,
+        dir_cache_key: Optional[str] = None,
+    ) -> Dir[T]:
         """
         Asynchronously create a new Dir by uploading a local directory to the configured remote store.
 
         Args:
             local_path: Path to the local directory
             remote_path: Optional path to store the directory remotely. If None, a path will be generated.
+            dir_cache_key: If you have a precomputed hash value you want to use when computing cache keys for
+                discoverable tasks that this File is an input to.
 
         Returns:
             A new Dir instance pointing to the uploaded directory
@@ -262,13 +270,34 @@ class Dir(BaseModel, Generic[T], SerializableType):
         Example:
         ```python
         remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/')
+        # With a known hash value you want to use for cache key calculation
+        remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/', dir_cache_key='abc123')
         ```
         """
         local_path_str = str(local_path)
         dirname = os.path.basename(os.path.normpath(local_path_str))
 
         output_path = await storage.put(from_path=local_path_str, to_path=remote_path, recursive=True)
-        return cls(path=output_path, name=dirname)
+        return cls(path=output_path, name=dirname, hash=dir_cache_key)
+
+    @classmethod
+    def from_existing_remote(cls, remote_path: str, dir_cache_key: Optional[str] = None) -> Dir[T]:
+        """
+        Create a Dir reference from an existing remote directory.
+
+        Args:
+            remote_path: The remote path to the existing directory
+            dir_cache_key: Optional hash value to use for cache key computation. If not specified,
+                the cache key will be computed based on this object's attributes.
+
+        Example:
+        ```python
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/")
+        # With a known hash
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
+        ```
+        """
+        return cls(path=remote_path, hash=dir_cache_key)
 
     @classmethod
     def from_local_sync(cls, local_path: Union[str, Path], remote_path: Optional[str] = None) -> Dir[T]:
@@ -414,7 +443,8 @@ class DirTransformer(TypeTransformer[Dir]):
                     ),
                     uri=python_val.path,
                 )
-            )
+            ),
+            hash=python_val.hash if python_val.hash else None,
         )
 
     async def to_python_value(
@@ -432,7 +462,8 @@ class DirTransformer(TypeTransformer[Dir]):
 
         uri = lv.scalar.blob.uri
         filename = Path(uri).name
-        f: Dir = Dir(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format)
+        hash_value = lv.hash if lv.hash else None
+        f: Dir = Dir(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format, hash=hash_value)
         return f
 
     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[Dir]:
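A usage sketch for the new `dir_cache_key` plumbing (the `flyte.io` import path, bucket paths, and key value are assumptions):

```python
import asyncio

from flyte.io import Dir  # assumed public import path


async def main() -> None:
    # Upload a local directory and attach a precomputed hash for cache-key computation.
    uploaded = await Dir.from_local("/tmp/data_dir/", "s3://bucket/data/", dir_cache_key="abc123")

    # Reference an already-remote directory, again carrying an optional cache key.
    existing = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
    print(uploaded.hash, existing.hash)


asyncio.run(main())
```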