flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,17 @@
  from __future__ import annotations
 
  import _datetime
- import asyncio
  import collections
  import types
  import typing
  from abc import ABC, abstractmethod
- from dataclasses import dataclass, field, is_dataclass
+ from dataclasses import is_dataclass
  from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union
 
- import msgpack
- from flyteidl.core import literals_pb2, types_pb2
+ from flyteidl2.core import literals_pb2, types_pb2
  from fsspec.utils import get_protocol
- from mashumaro.mixins.json import DataClassJSONMixin
  from mashumaro.types import SerializableType
- from pydantic import model_serializer, model_validator
+ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
  from typing_extensions import Annotated, TypeAlias, get_args, get_origin
 
  import flyte.storage as storage
@@ -48,15 +45,23 @@ GENERIC_FORMAT: DataFrameFormat = ""
  GENERIC_PROTOCOL: str = "generic protocol"
 
 
- @dataclass
- class DataFrame(SerializableType, DataClassJSONMixin):
+ class DataFrame(BaseModel, SerializableType):
      """
      This is the user facing DataFrame class. Please don't confuse it with the literals.StructuredDataset
      class (that is just a model, a Python class representation of the protobuf).
      """
 
-     uri: typing.Optional[str] = field(default=None)
-     file_format: typing.Optional[str] = field(default=GENERIC_FORMAT)
+     uri: typing.Optional[str] = Field(default=None)
+     format: typing.Optional[str] = Field(default=GENERIC_FORMAT)
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     # Private attributes that are not part of the Pydantic model schema
+     _raw_df: typing.Optional[typing.Any] = PrivateAttr(default=None)
+     _metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = PrivateAttr(default=None)
+     _literal_sd: Optional[literals_pb2.StructuredDataset] = PrivateAttr(default=None)
+     _dataframe_type: Optional[Type[Any]] = PrivateAttr(default=None)
+     _already_uploaded: bool = PrivateAttr(default=False)
 
      # loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
      def _serialize(self) -> Dict[str, Optional[str]]:
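
The hunk above converts `DataFrame` from a mashumaro dataclass into a Pydantic `BaseModel`: `uri` and `format` (renamed from `file_format`) are the only schema fields, and all runtime-only state moves into `PrivateAttr`s. A minimal sketch of the resulting surface, assuming `DataFrame` is re-exported from `flyte.io` as the file list above suggests:

```python
from flyte.io import DataFrame  # import path assumed from the file layout

# uri and format are ordinary validated model fields; the old file_format
# keyword no longer exists on the constructor.
df = DataFrame(uri="s3://my-bucket/data.parquet", format="parquet")
print(df.uri, df.format)

# Private attributes (_raw_df, _literal_sd, ...) are not constructor
# arguments and do not appear in the model schema.
print(list(DataFrame.model_fields))  # ['uri', 'format']
```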
@@ -65,16 +70,16 @@ class DataFrame(SerializableType, DataClassJSONMixin):
          engine = DataFrameTransformerEngine()
          lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
          sd = DataFrame(uri=lv.scalar.structured_dataset.uri)
-         sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+         sd.format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
          return {
              "uri": sd.uri,
-             "file_format": sd.file_format,
+             "format": sd.format,
          }
 
      @classmethod
-     def _deserialize(cls, value) -> "DataFrame":
+     def _deserialize(cls, value) -> DataFrame:
          uri = value.get("uri", None)
-         file_format = value.get("file_format", None)
+         format_val = value.get("format", None)
 
          if uri is None:
              raise ValueError("DataFrame's uri and file format should not be None")
@@ -86,7 +91,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
              scalar=literals_pb2.Scalar(
                  structured_dataset=literals_pb2.StructuredDataset(
                      metadata=literals_pb2.StructuredDatasetMetadata(
-                         structured_dataset_type=types_pb2.StructuredDatasetType(format=file_format)
+                         structured_dataset_type=types_pb2.StructuredDatasetType(format=format_val)
                      ),
                      uri=uri,
                  )
@@ -102,7 +107,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
          lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
          return {
              "uri": lv.scalar.structured_dataset.uri,
-             "file_format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+             "format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
          }
 
      @model_validator(mode="after")
@@ -117,7 +122,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
              scalar=literals_pb2.Scalar(
                  structured_dataset=literals_pb2.StructuredDataset(
                      metadata=literals_pb2.StructuredDatasetMetadata(
-                         structured_dataset_type=types_pb2.StructuredDatasetType(format=self.file_format)
+                         structured_dataset_type=types_pb2.StructuredDatasetType(format=self.format)
                      ),
                      uri=self.uri,
                  )
@@ -134,30 +139,46 @@ class DataFrame(SerializableType, DataClassJSONMixin):
      def column_names(cls) -> typing.List[str]:
          return [k for k, v in cls.columns().items()]
 
-     def __init__(
-         self,
+     @classmethod
+     def from_df(
+         cls,
          val: typing.Optional[typing.Any] = None,
          uri: typing.Optional[str] = None,
-         metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = None,
+     ) -> DataFrame:
+         """
+         Wrapper to create a DataFrame from a dataframe.
+         The reason this is implemented as a wrapper instead of a full translation invoking
+         the type engine and the encoders is because there's too much information in the type
+         signature of the task that we don't want the user to have to replicate.
+         """
+         instance = cls(uri=uri)
+         instance._raw_df = val
+         return instance
+
+     @classmethod
+     def from_existing_remote(
+         cls,
+         remote_path: str,
+         format: typing.Optional[str] = None,
          **kwargs,
-     ):
-         self._val = val
-         # Make these fields public, so that the dataclass transformer can set a value for it
-         # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
-         self.uri = uri
-         # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
-         self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
-         # This is a special attribute that indicates if the data was either downloaded or uploaded
-         self._metadata = metadata
-         # This is not for users to set, the transformer will set this.
-         self._literal_sd: Optional[literals_pb2.StructuredDataset] = None
-         # Not meant for users to set, will be set by an open() call
-         self._dataframe_type: Optional[DF] = None  # type: ignore
-         self._already_uploaded = False
+     ) -> "DataFrame":
+         """
+         Create a DataFrame reference from an existing remote dataframe.
+
+         Args:
+             remote_path: The remote path to the existing dataframe
+             format: Format of the stored dataframe
+
+         Example:
+         ```python
+         df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
+         ```
+         """
+         return cls(uri=remote_path, format=format or GENERIC_FORMAT, **kwargs)
 
      @property
      def val(self) -> Optional[DF]:
-         return self._val
+         return self._raw_df
 
      @property
      def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
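
Together, `from_df` and `from_existing_remote` replace the removed `__init__` overloading. A short usage sketch (pandas stands in for any dataframe type with a registered handler; the `flyte.io` import path is assumed from the file layout):

```python
import pandas as pd

from flyte.io import DataFrame  # import path assumed

pdf = pd.DataFrame({"a": [1, 2, 3]})

# Wrap an in-memory dataframe; the transformer engine uploads it when the
# value is serialized as a task output.
wrapped = DataFrame.from_df(pdf)
assert wrapped.val is pdf

# Reference data that already lives remotely; nothing is uploaded.
ref = DataFrame.from_existing_remote("s3://my-bucket/data.parquet", format="parquet")
assert ref.uri == "s3://my-bucket/data.parquet"
```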
@@ -201,7 +222,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
 
          @task
          def return_df() -> DataFrame:
-             df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
+             df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", format="parquet")
              df = df.open(pd.DataFrame).all()
              return df
 
@@ -244,6 +265,9 @@ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
              fields = getattr(value, "__dataclass_fields__")
              d = {k: v.type for k, v in fields.items()}
              result.update(flatten_dict(sub_dict=d, parent_key=current_key))
+         elif hasattr(value, "model_fields"):  # Pydantic model
+             d = {k: v.annotation for k, v in value.model_fields.items()}
+             result.update(flatten_dict(sub_dict=d, parent_key=current_key))
          else:
              result[current_key] = value
      return result
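
The new `elif` branch mirrors the dataclass branch one level up: for values that are Pydantic model classes, it recurses into `model_fields` and flattens each field's `annotation`, where the dataclass branch uses `__dataclass_fields__` and `.type`. A sketch under the dotted parent-key convention implied by `current_key`:

```python
from pydantic import BaseModel

from flyte.io._dataframe.dataframe import flatten_dict  # module shown in this diff


class Point(BaseModel):
    x: int
    y: float


# Dataclass values recurse via __dataclass_fields__ (existing branch);
# Pydantic model values now recurse via model_fields (new branch).
flat = flatten_dict({"name": str, "location": Point})
# Expected shape (key separator assumed):
# {"name": str, "location.x": int, "location.y": float}
```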
@@ -623,17 +647,21 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
                      f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
                  )
              lowest_level[h.supported_format] = h
-         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")
+         logger.debug(
+             f"Registered {h.__class__.__name__} as handler for {h.python_type.__class__.__name__},"
+             f" protocol {protocol}, fmt {h.supported_format}"
+         )
 
          if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
              if h.python_type in cls.DEFAULT_FORMATS and not override:
                  if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
                      logger.info(
-                         f"Not using handler {h} with format {h.supported_format}"
-                         f" as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                         f"Not using handler {h.__class__.__name__} with format {h.supported_format}"
+                         f" as default for {h.python_type.__class__.__name__},"
+                         f" {cls.DEFAULT_FORMATS[h.python_type]} already specified."
                      )
              else:
-                 logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type}.")
+                 logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type.__class__.__name__}.")
                  cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
          if default_storage_for_type or default_for_type:
              if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
@@ -661,7 +689,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          expected: types_pb2.LiteralType,
      ) -> literals_pb2.Literal:
          # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
-         python_type, *attrs = extract_cols_and_format(python_type)
+         python_type, *_attrs = extract_cols_and_format(python_type)
          sdt = types_pb2.StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))
 
          if issubclass(python_type, DataFrame) and not isinstance(python_val, DataFrame):
708
736
  # return DataFrame(uri=uri)
709
737
  if python_val.val is None:
710
738
  uri = python_val.uri
711
- file_format = python_val.file_format
739
+ format_val = python_val.format
712
740
 
713
741
  # Check the user-specified uri
714
742
  if not uri:
715
743
  raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
716
744
  if not storage.is_remote(uri):
717
- uri = await storage.put(uri)
745
+ uri = await storage.put(uri, recursive=True)
718
746
 
719
- # Check the user-specified file_format
720
- # When users specify file_format for a DataFrame, the file_format should be retained
747
+ # Check the user-specified format
748
+ # When users specify format for a DataFrame, the format should be retained
721
749
  # conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
722
750
  # Following illustrates why we can't always copy the user-specified file_format over:
723
751
  #
@@ -725,14 +753,14 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
              # def modify_format(df: Annotated[DataFrame, {}, "task-format"]) -> DataFrame:
              #     return df
              #
-             # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
+             # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", format="user-format")
              # df2 = modify_format(df=df)
              #
-             # In this case, we expect the df2.file_format to be task-format (as shown in Annotated),
-             # not user-format. If we directly copy the user-specified file_format over,
+             # In this case, we expect the df2.format to be task-format (as shown in Annotated),
+             # not user-format. If we directly copy the user-specified format over,
              # the type hint information will be missing.
-             if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
-                 sdt.format = file_format
+             if sdt.format == GENERIC_FORMAT and format_val != GENERIC_FORMAT:
+                 sdt.format = format_val
 
              sd_model = literals_pb2.StructuredDataset(
                  uri=uri,
@@ -760,8 +788,9 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
              structured_dataset_type=expected.structured_dataset_type if expected else None
          )
 
-         sd = DataFrame(val=python_val, metadata=meta)
-         return await self.encode(sd, python_type, protocol, fmt, sdt)
+         fdf = DataFrame.from_df(val=python_val)
+         fdf._metadata = meta
+         return await self.encode(fdf, python_type, protocol, fmt, sdt)
 
      def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
          """
@@ -782,7 +811,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
 
      async def encode(
          self,
-         sd: DataFrame,
+         df: DataFrame,
          df_type: Type,
          protocol: str,
          format: str,
@@ -791,7 +820,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          handler: DataFrameEncoder
          handler = self.get_encoder(df_type, protocol, format)
 
-         sd_model = await handler.encode(sd, structured_literal_type)
+         sd_model = await handler.encode(df, structured_literal_type)
          # This block is here in case the encoder did not set the type information in the metadata. Since this literal
          # is special in that it carries around the type itself, we want to make sure the type info therein is at
          # least as good as the type of the interface.
@@ -807,72 +836,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
 
          # Because the handler.encode may have uploaded something, and because the sd may end up living inside a
-         # dataclass, we need to modify any uploaded flyte:// urls here.
-         modify_literal_uris(lit)  # todo: verify that this can be removed.
-         sd._literal_sd = sd_model
-         sd._already_uploaded = True
+         # dataclass, we need to modify any uploaded flyte:// urls here. Needed here even though the Type engine
+         # already does this because the DataframeTransformerEngine may be called directly.
+         modify_literal_uris(lit)
+         df._literal_sd = sd_model
+         df._already_uploaded = True
          return lit
 
-     # pr: han-ru: can this be removed if we make DataFrame a pydantic model?
-     def dict_to_dataframe(
-         self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | DataFrame
-     ) -> T | DataFrame:
-         uri = dict_obj.get("uri", None)
-         file_format = dict_obj.get("file_format", None)
-
-         if uri is None:
-             raise ValueError("DataFrame's uri and file format should not be None")
-
-         # Instead of using python native DataFrame, we need to build a literals.StructuredDataset
-         # The reason is that _literal_sd of python sd is accessed when task output LiteralMap is
-         # converted back to flyteidl. Hence, _literal_sd must have to_flyte_idl method
-         # See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
-         # For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
-         sdt = types_pb2.StructuredDatasetType(format=file_format)
-         metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
-         sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
-
-         return asyncio.run(
-             DataFrameTransformerEngine().to_python_value(
-                 literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
-                 expected_python_type,
-             )
-         )
-
-     def from_binary_idl(
-         self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | DataFrame
-     ) -> T | DataFrame:
-         """
-         If the input is from flytekit, the Life Cycle will be as follows:
-
-         Life Cycle:
-         binary IDL                 -> resolved binary         -> bytes                   -> expected Python object
-         (flytekit customized          (propeller processing)     (flytekit binary IDL)      (flytekit customized
-         serialization)                                                                       deserialization)
-
-         Example Code:
-             @dataclass
-             class DC:
-                 sd: StructuredDataset
-
-             @workflow
-             def wf(dc: DC):
-                 t_sd(dc.sd)
-
-         Note:
-         - The deserialization is the same as put a structured dataset in a dataclass,
-           which will deserialize by the mashumaro's API.
-
-         Related PR:
-         - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
-         - Link: https://github.com/flyteorg/flytekit/pull/2554
-         """
-         if binary_idl_object.tag == MESSAGEPACK:
-             python_val = msgpack.loads(binary_idl_object.value)
-             return self.dict_to_dataframe(dict_obj=python_val, expected_python_type=expected_python_type)
-         else:
-             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
-
      async def to_python_value(
          self, lv: literals_pb2.Literal, expected_python_type: Type[T] | DataFrame
      ) -> T | DataFrame:
@@ -906,12 +876,11 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          |                             | the running task's signature.           |                                      |
          +-----------------------------+-----------------------------------------+--------------------------------------+
          """
-         # Handle dataclass attribute access
          if lv.HasField("scalar") and lv.scalar.HasField("binary"):
-             return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+             raise TypeTransformerFailedError("Attribute access unsupported.")
 
          # Detect annotations and extract out all the relevant information that the user might supply
-         expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
+         expected_python_type, column_dict, _storage_fmt, _pa_schema = extract_cols_and_format(expected_python_type)
 
          # Start handling for DataFrame scalars, first look at the columns
          incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
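
Combined with the deletion of `from_binary_idl` and `dict_to_dataframe` in the previous hunk, this means a `DataFrame` arriving as a msgpack binary scalar, the dataclass attribute-access path, is now rejected instead of decoded. Illustrative only, reusing names from the hunks above (the `Binary` field names are taken from the flyteidl `core.Binary` message and assumed unchanged in flyteidl2):

```python
# A literal carrying a binary (msgpack) scalar used to be routed through
# from_binary_idl; under the new code, to_python_value raises instead.
lv = literals_pb2.Literal(
    scalar=literals_pb2.Scalar(binary=literals_pb2.Binary(value=b"...", tag="msgpack"))
)
# await DataFrameTransformerEngine().to_python_value(lv, DataFrame)
# --> TypeTransformerFailedError("Attribute access unsupported.")
```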
@@ -939,16 +908,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          # t1(input_a: DataFrame) # or
          # t1(input_a: Annotated[DataFrame, my_cols])
          if issubclass(expected_python_type, DataFrame):
-             sd = expected_python_type(
-                 dataframe=None,
-                 # Note here that the type being passed in
-                 metadata=metad,
-             )
-             sd._literal_sd = lv.scalar.structured_dataset
-             sd.file_format = metad.structured_dataset_type.format
-             return sd
+             fdf = DataFrame(format=metad.structured_dataset_type.format, uri=lv.scalar.structured_dataset.uri)
+             fdf._already_uploaded = True
+             fdf._literal_sd = lv.scalar.structured_dataset
+             fdf._metadata = metad
+             return fdf
 
-         # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
+         # If the requested type was not a flyte.DataFrame, then it means it was a raw dataframe type, which means
          # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
          return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
 
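
The two return paths above map to the two ways a task can type its input: ask for the `DataFrame` wrapper and defer I/O, or ask for a concrete dataframe type and have `open_as` materialize it immediately. A sketch mirroring the docstring example earlier in this file (same `@task` spelling; pandas assumed):

```python
import pandas as pd


@task
def lazy(df: DataFrame) -> int:
    # First branch: the wrapper comes back with _literal_sd/_metadata set;
    # no bytes move until open(...).all() is called.
    pdf = df.open(pd.DataFrame).all()
    return len(pdf)


@task
def eager(df: pd.DataFrame) -> int:
    # Second branch: open_as() has already downloaded and decoded the frame.
    return len(df)
```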
@@ -1024,7 +990,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
          return converted_cols
 
      def _get_dataset_type(self, t: typing.Union[Type[DataFrame], typing.Any]) -> types_pb2.StructuredDatasetType:
-         original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t)  # type: ignore
+         _original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t)  # type: ignore
 
          # Get the column information
          converted_cols: typing.List[types_pb2.StructuredDatasetType.DatasetColumn] = (
@@ -1051,7 +1017,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
      def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[DataFrame]:
          # todo: technically we should return the dataframe type specified in the constructor, but to do that,
          # we'd have to store that, which we don't do today. See possibly #1363
-         if literal_type.HasField("dataframe_type"):
+         if literal_type.HasField("structured_dataset_type"):
              return DataFrame
          raise ValueError(f"DataFrameTransformerEngine cannot reverse {literal_type}")
 