flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- flyte/__init__.py +18 -2
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +62 -8
- flyte/_cache/cache.py +4 -2
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +12 -4
- flyte/_code_bundle/_packaging.py +13 -9
- flyte/_code_bundle/_utils.py +18 -10
- flyte/_code_bundle/bundle.py +17 -9
- flyte/_constants.py +1 -0
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +38 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +307 -0
- flyte/_deploy.py +235 -61
- flyte/_environment.py +20 -6
- flyte/_excepthook.py +1 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +178 -81
- flyte/_initialize.py +132 -51
- flyte/_interface.py +39 -2
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +70 -29
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_action.py +14 -16
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +68 -70
- flyte/_internal/controllers/remote/_core.py +127 -99
- flyte/_internal/controllers/remote/_informer.py +19 -10
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +181 -69
- flyte/_internal/imagebuild/image_builder.py +0 -5
- flyte/_internal/imagebuild/remote_builder.py +155 -64
- flyte/_internal/imagebuild/utils.py +51 -2
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +110 -21
- flyte/_internal/runtime/entrypoints.py +27 -1
- flyte/_internal/runtime/io.py +21 -8
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/rusty.py +20 -5
- flyte/_internal/runtime/task_serde.py +34 -19
- flyte/_internal/runtime/taskrunner.py +22 -4
- flyte/_internal/runtime/trigger_serde.py +160 -0
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/__init__.py +0 -0
- flyte/_keyring/file.py +115 -0
- flyte/_logging.py +201 -39
- flyte/_map.py +111 -14
- flyte/_module.py +70 -0
- flyte/_pod.py +4 -3
- flyte/_resources.py +213 -31
- flyte/_run.py +110 -39
- flyte/_task.py +75 -16
- flyte/_task_environment.py +105 -29
- flyte/_task_plugins.py +4 -2
- flyte/_trace.py +5 -0
- flyte/_trigger.py +1000 -0
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/coro_management.py +2 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +17 -2
- flyte/_version.py +3 -3
- flyte/cli/_abort.py +3 -3
- flyte/cli/_build.py +3 -6
- flyte/cli/_common.py +78 -7
- flyte/cli/_create.py +182 -4
- flyte/cli/_delete.py +23 -1
- flyte/cli/_deploy.py +63 -16
- flyte/cli/_get.py +79 -34
- flyte/cli/_params.py +26 -10
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_run.py +151 -26
- flyte/cli/_serve.py +64 -0
- flyte/cli/_update.py +37 -0
- flyte/cli/_user.py +17 -0
- flyte/cli/main.py +30 -4
- flyte/config/_config.py +10 -6
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +29 -8
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +197 -0
- flyte/connectors/utils.py +135 -0
- flyte/errors.py +22 -2
- flyte/extend.py +8 -1
- flyte/extras/_container.py +6 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +21 -0
- flyte/io/__init__.py +2 -0
- flyte/io/_dataframe/__init__.py +2 -0
- flyte/io/_dataframe/basic_dfs.py +17 -8
- flyte/io/_dataframe/dataframe.py +98 -132
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +582 -139
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +74 -15
- flyte/remote/__init__.py +6 -1
- flyte/remote/_action.py +34 -26
- flyte/remote/_client/_protocols.py +39 -4
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
- flyte/remote/_client/auth/_channel.py +10 -6
- flyte/remote/_client/controlplane.py +17 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_data.py +6 -6
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +64 -8
- flyte/remote/_secret.py +26 -17
- flyte/remote/_task.py +75 -33
- flyte/remote/_trigger.py +306 -0
- flyte/remote/_user.py +33 -0
- flyte/report/_report.py +1 -1
- flyte/storage/__init__.py +6 -1
- flyte/storage/_config.py +5 -1
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +200 -103
- flyte/types/__init__.py +16 -0
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +35 -8
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +40 -70
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b30.data/scripts/debug.py +38 -0
- {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
- flyte-2.0.0b30.dist-info/RECORD +192 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -93
- flyte/_protos/common/identifier_pb2.pyi +0 -110
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -59
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -27
- flyte/_protos/workflow/common_pb2.pyi +0 -14
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -109
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -121
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -137
- flyte/_protos/workflow/run_service_pb2.pyi +0 -185
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -79
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -60
- flyte/_protos/workflow/task_service_pb2.pyi +0 -59
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte-2.0.0b13.dist-info/RECORD +0 -239
- /flyte/{_protos → _debug}/__init__.py +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/io/_dataframe/dataframe.py
CHANGED

@@ -1,20 +1,17 @@
 from __future__ import annotations
 
 import _datetime
-import asyncio
 import collections
 import types
 import typing
 from abc import ABC, abstractmethod
-from dataclasses import
+from dataclasses import is_dataclass
 from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union
 
-import
-from flyteidl.core import literals_pb2, types_pb2
+from flyteidl2.core import literals_pb2, types_pb2
 from fsspec.utils import get_protocol
-from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
-from pydantic import model_serializer, model_validator
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
 from typing_extensions import Annotated, TypeAlias, get_args, get_origin
 
 import flyte.storage as storage
@@ -48,15 +45,23 @@ GENERIC_FORMAT: DataFrameFormat = ""
 GENERIC_PROTOCOL: str = "generic protocol"
 
 
-
-class DataFrame(SerializableType, DataClassJSONMixin):
+class DataFrame(BaseModel, SerializableType):
     """
     This is the user facing DataFrame class. Please don't confuse it with the literals.StructuredDataset
     class (that is just a model, a Python class representation of the protobuf).
     """
 
-    uri: typing.Optional[str] =
-
+    uri: typing.Optional[str] = Field(default=None)
+    format: typing.Optional[str] = Field(default=GENERIC_FORMAT)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    # Private attributes that are not part of the Pydantic model schema
+    _raw_df: typing.Optional[typing.Any] = PrivateAttr(default=None)
+    _metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = PrivateAttr(default=None)
+    _literal_sd: Optional[literals_pb2.StructuredDataset] = PrivateAttr(default=None)
+    _dataframe_type: Optional[Type[Any]] = PrivateAttr(default=None)
+    _already_uploaded: bool = PrivateAttr(default=False)
 
     # loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
     def _serialize(self) -> Dict[str, Optional[str]]:
@@ -65,16 +70,16 @@ class DataFrame(SerializableType, DataClassJSONMixin):
         engine = DataFrameTransformerEngine()
         lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
         sd = DataFrame(uri=lv.scalar.structured_dataset.uri)
-        sd.
+        sd.format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
         return {
             "uri": sd.uri,
-            "
+            "format": sd.format,
         }
 
     @classmethod
-    def _deserialize(cls, value) ->
+    def _deserialize(cls, value) -> DataFrame:
         uri = value.get("uri", None)
-
+        format_val = value.get("format", None)
 
         if uri is None:
             raise ValueError("DataFrame's uri and file format should not be None")
@@ -86,7 +91,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
             scalar=literals_pb2.Scalar(
                 structured_dataset=literals_pb2.StructuredDataset(
                     metadata=literals_pb2.StructuredDatasetMetadata(
-                        structured_dataset_type=types_pb2.StructuredDatasetType(format=
+                        structured_dataset_type=types_pb2.StructuredDatasetType(format=format_val)
                     ),
                     uri=uri,
                 )
@@ -102,7 +107,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
         lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
         return {
             "uri": lv.scalar.structured_dataset.uri,
-            "
+            "format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
         }
 
     @model_validator(mode="after")
@@ -117,7 +122,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
             scalar=literals_pb2.Scalar(
                 structured_dataset=literals_pb2.StructuredDataset(
                     metadata=literals_pb2.StructuredDatasetMetadata(
-                        structured_dataset_type=types_pb2.StructuredDatasetType(format=self.
+                        structured_dataset_type=types_pb2.StructuredDatasetType(format=self.format)
                    ),
                     uri=self.uri,
                 )
@@ -134,30 +139,46 @@ class DataFrame(SerializableType, DataClassJSONMixin):
     def column_names(cls) -> typing.List[str]:
         return [k for k, v in cls.columns().items()]
 
-
-
+    @classmethod
+    def from_df(
+        cls,
         val: typing.Optional[typing.Any] = None,
         uri: typing.Optional[str] = None,
-
+    ) -> DataFrame:
+        """
+        Wrapper to create a DataFrame from a dataframe.
+        The reason this is implemented as a wrapper instead of a full translation invoking
+        the type engine and the encoders is because there's too much information in the type
+        signature of the task that we don't want the user to have to replicate.
+        """
+        instance = cls(uri=uri)
+        instance._raw_df = val
+        return instance
+
+    @classmethod
+    def from_existing_remote(
+        cls,
+        remote_path: str,
+        format: typing.Optional[str] = None,
         **kwargs,
-    ):
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ) -> "DataFrame":
+        """
+        Create a DataFrame reference from an existing remote dataframe.
+
+        Args:
+            remote_path: The remote path to the existing dataframe
+            format: Format of the stored dataframe
+
+        Example:
+        ```python
+        df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
+        ```
+        """
+        return cls(uri=remote_path, format=format or GENERIC_FORMAT, **kwargs)
 
     @property
     def val(self) -> Optional[DF]:
-        return self.
+        return self._raw_df
 
     @property
     def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
@@ -201,7 +222,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
 
         @task
         def return_df() -> DataFrame:
-            df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet",
+            df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", format="parquet")
             df = df.open(pd.DataFrame).all()
             return df
 
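The class-level hunks above replace the b13 mashumaro dataclass with a Pydantic BaseModel, move the raw value and literal bookkeeping into PrivateAttr slots, and rename the public file_format field to format. A minimal usage sketch of the two constructors added here, assuming DataFrame is importable from flyte.io and pandas is installed (the s3 path is a placeholder):

```python
# Usage sketch based only on the signatures shown in this diff.
import pandas as pd

from flyte.io import DataFrame

# Wrap an in-memory dataframe; from_df() keeps the value on the private
# _raw_df attribute rather than a public pydantic field.
local = DataFrame.from_df(val=pd.DataFrame({"a": [1, 2, 3]}))
print(local.val)  # the wrapped pandas object, via the val property

# Reference an already-uploaded dataframe without re-encoding it.
remote = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
print(remote.uri, remote.format)
```

from_df wraps a value produced inside a task, while from_existing_remote only records a uri and format, so no data needs to be copied or re-encoded.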
@@ -244,6 +265,9 @@ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
             fields = getattr(value, "__dataclass_fields__")
             d = {k: v.type for k, v in fields.items()}
             result.update(flatten_dict(sub_dict=d, parent_key=current_key))
+        elif hasattr(value, "model_fields"):  # Pydantic model
+            d = {k: v.annotation for k, v in value.model_fields.items()}
+            result.update(flatten_dict(sub_dict=d, parent_key=current_key))
         else:
             result[current_key] = value
     return result
@@ -623,17 +647,21 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
                 f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
             )
         lowest_level[h.supported_format] = h
-        logger.debug(
+        logger.debug(
+            f"Registered {h.__class__.__name__} as handler for {h.python_type.__class__.__name__},"
+            f" protocol {protocol}, fmt {h.supported_format}"
+        )
 
         if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
             if h.python_type in cls.DEFAULT_FORMATS and not override:
                 if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
                     logger.info(
-                        f"Not using handler {h} with format {h.supported_format}"
-                        f" as default for {h.python_type
+                        f"Not using handler {h.__class__.__name__} with format {h.supported_format}"
+                        f" as default for {h.python_type.__class__.__name__},"
+                        f" {cls.DEFAULT_FORMATS[h.python_type]} already specified."
                     )
                 else:
-                    logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type}.")
+                    logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type.__class__.__name__}.")
                 cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
         if default_storage_for_type or default_for_type:
             if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
@@ -661,7 +689,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         expected: types_pb2.LiteralType,
     ) -> literals_pb2.Literal:
         # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
-        python_type, *
+        python_type, *_attrs = extract_cols_and_format(python_type)
         sdt = types_pb2.StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))
 
         if issubclass(python_type, DataFrame) and not isinstance(python_val, DataFrame):
@@ -708,16 +736,16 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         # return DataFrame(uri=uri)
         if python_val.val is None:
             uri = python_val.uri
-
+            format_val = python_val.format
 
             # Check the user-specified uri
             if not uri:
                 raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
             if not storage.is_remote(uri):
-                uri = await storage.put(uri)
+                uri = await storage.put(uri, recursive=True)
 
-            # Check the user-specified
-            # When users specify
+            # Check the user-specified format
+            # When users specify format for a DataFrame, the format should be retained
             # conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
             # Following illustrates why we can't always copy the user-specified file_format over:
             #
@@ -725,14 +753,14 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
             # def modify_format(df: Annotated[DataFrame, {}, "task-format"]) -> DataFrame:
             #     return df
             #
-            # df = DataFrame(uri="s3://my-s3-bucket/df.parquet",
+            # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", format="user-format")
             # df2 = modify_format(df=df)
             #
-            # In this case, we expect the df2.
-            # not user-format. If we directly copy the user-specified
+            # In this case, we expect the df2.format to be task-format (as shown in Annotated),
+            # not user-format. If we directly copy the user-specified format over,
             # the type hint information will be missing.
-            if sdt.format == GENERIC_FORMAT and
-            sdt.format =
+            if sdt.format == GENERIC_FORMAT and format_val != GENERIC_FORMAT:
+                sdt.format = format_val
 
             sd_model = literals_pb2.StructuredDataset(
                 uri=uri,
@@ -760,8 +788,9 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
             structured_dataset_type=expected.structured_dataset_type if expected else None
         )
 
-
-
+        fdf = DataFrame.from_df(val=python_val)
+        fdf._metadata = meta
+        return await self.encode(fdf, python_type, protocol, fmt, sdt)
 
     def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
         """
@@ -782,7 +811,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
 
     async def encode(
         self,
-
+        df: DataFrame,
         df_type: Type,
         protocol: str,
         format: str,
@@ -791,7 +820,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         handler: DataFrameEncoder
         handler = self.get_encoder(df_type, protocol, format)
 
-        sd_model = await handler.encode(
+        sd_model = await handler.encode(df, structured_literal_type)
         # This block is here in case the encoder did not set the type information in the metadata. Since this literal
         # is special in that it carries around the type itself, we want to make sure the type info therein is at
         # least as good as the type of the interface.
@@ -807,72 +836,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
 
         # Because the handler.encode may have uploaded something, and because the sd may end up living inside a
-        # dataclass, we need to modify any uploaded flyte:// urls here.
-
-
-
+        # dataclass, we need to modify any uploaded flyte:// urls here. Needed here even though the Type engine
+        # already does this because the DataframeTransformerEngine may be called directly.
+        modify_literal_uris(lit)
+        df._literal_sd = sd_model
+        df._already_uploaded = True
         return lit
 
-    # pr: han-ru: can this be removed if we make DataFrame a pydantic model?
-    def dict_to_dataframe(
-        self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | DataFrame
-    ) -> T | DataFrame:
-        uri = dict_obj.get("uri", None)
-        file_format = dict_obj.get("file_format", None)
-
-        if uri is None:
-            raise ValueError("DataFrame's uri and file format should not be None")
-
-        # Instead of using python native DataFrame, we need to build a literals.StructuredDataset
-        # The reason is that _literal_sd of python sd is accessed when task output LiteralMap is
-        # converted back to flyteidl. Hence, _literal_sd must have to_flyte_idl method
-        # See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
-        # For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
-        sdt = types_pb2.StructuredDatasetType(format=file_format)
-        metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
-        sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
-
-        return asyncio.run(
-            DataFrameTransformerEngine().to_python_value(
-                literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
-                expected_python_type,
-            )
-        )
-
-    def from_binary_idl(
-        self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | DataFrame
-    ) -> T | DataFrame:
-        """
-        If the input is from flytekit, the Life Cycle will be as follows:
-
-        Life Cycle:
-        binary IDL -> resolved binary -> bytes -> expected Python object
-        (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized
-        serialization) deserialization)
-
-        Example Code:
-        @dataclass
-        class DC:
-            sd: StructuredDataset
-
-        @workflow
-        def wf(dc: DC):
-            t_sd(dc.sd)
-
-        Note:
-        - The deserialization is the same as put a structured dataset in a dataclass,
-          which will deserialize by the mashumaro's API.
-
-        Related PR:
-        - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
-        - Link: https://github.com/flyteorg/flytekit/pull/2554
-        """
-        if binary_idl_object.tag == MESSAGEPACK:
-            python_val = msgpack.loads(binary_idl_object.value)
-            return self.dict_to_dataframe(dict_obj=python_val, expected_python_type=expected_python_type)
-        else:
-            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
-
     async def to_python_value(
         self, lv: literals_pb2.Literal, expected_python_type: Type[T] | DataFrame
     ) -> T | DataFrame:
@@ -906,12 +876,11 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
        | | the running task's signature. | |
        +-----------------------------+-----------------------------------------+--------------------------------------+
        """
-        # Handle dataclass attribute access
         if lv.HasField("scalar") and lv.scalar.HasField("binary"):
-
+            raise TypeTransformerFailedError("Attribute access unsupported.")
 
         # Detect annotations and extract out all the relevant information that the user might supply
-        expected_python_type, column_dict,
+        expected_python_type, column_dict, _storage_fmt, _pa_schema = extract_cols_and_format(expected_python_type)
 
         # Start handling for DataFrame scalars, first look at the columns
         incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
@@ -939,16 +908,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
         # t1(input_a: DataFrame) # or
         # t1(input_a: Annotated[DataFrame, my_cols])
         if issubclass(expected_python_type, DataFrame):
-
-
-
-
-
-            sd._literal_sd = lv.scalar.structured_dataset
-            sd.file_format = metad.structured_dataset_type.format
-            return sd
+            fdf = DataFrame(format=metad.structured_dataset_type.format, uri=lv.scalar.structured_dataset.uri)
+            fdf._already_uploaded = True
+            fdf._literal_sd = lv.scalar.structured_dataset
+            fdf._metadata = metad
+            return fdf
 
-        # If the requested type was not a
+        # If the requested type was not a flyte.DataFrame, then it means it was a raw dataframe type, which means
         # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
         return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
@@ -1024,7 +990,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
             return converted_cols
 
     def _get_dataset_type(self, t: typing.Union[Type[DataFrame], typing.Any]) -> types_pb2.StructuredDatasetType:
-
+        _original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t)  # type: ignore
 
         # Get the column information
         converted_cols: typing.List[types_pb2.StructuredDatasetType.DatasetColumn] = (
@@ -1051,7 +1017,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[DataFrame]:
         # todo: technically we should return the dataframe type specified in the constructor, but to do that,
         # we'd have to store that, which we don't do today. See possibly #1363
-        if literal_type.HasField("
+        if literal_type.HasField("structured_dataset_type"):
             return DataFrame
         raise ValueError(f"DataFrameTransformerEngine cannot reverse {literal_type}")
 
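A side effect of the rename is visible in _serialize/_deserialize and in the removed dict_to_dataframe: the serialized dict key is now "format" where b13 wrote "file_format". A hypothetical helper, not part of the package, for translating stored b13-style payloads using only the key names visible in this diff:

```python
# b13's dataclass serialization used a "file_format" key (see the removed
# dict_to_dataframe above); b30's pydantic serializer emits "format".
def migrate_payload(payload: dict) -> dict:
    # Return a copy with the b13 key renamed to the b30 key.
    out = dict(payload)
    if "file_format" in out:
        out["format"] = out.pop("file_format")
    return out

assert migrate_payload(
    {"uri": "s3://bucket/data.parquet", "file_format": "parquet"}
) == {"uri": "s3://bucket/data.parquet", "format": "parquet"}
```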