flyte 2.0.0b18__py3-none-any.whl → 2.0.0b20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/_bin/runtime.py +2 -1
- flyte/_initialize.py +4 -4
- flyte/_internal/controllers/__init__.py +4 -5
- flyte/_internal/controllers/_local_controller.py +5 -5
- flyte/_internal/controllers/remote/__init__.py +0 -2
- flyte/_internal/controllers/remote/_controller.py +19 -23
- flyte/_internal/controllers/remote/_core.py +120 -92
- flyte/_internal/controllers/remote/_informer.py +15 -6
- flyte/_map.py +90 -12
- flyte/_task.py +3 -0
- flyte/_version.py +3 -3
- flyte/cli/_create.py +4 -1
- flyte/cli/_deploy.py +4 -5
- flyte/cli/_params.py +18 -4
- flyte/cli/_run.py +2 -2
- flyte/config/_config.py +2 -2
- flyte/config/_reader.py +14 -8
- flyte/errors.py +12 -1
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +17 -0
- flyte/io/_dataframe/basic_dfs.py +16 -7
- flyte/io/_dataframe/dataframe.py +84 -123
- flyte/remote/_task.py +52 -22
- flyte/report/_report.py +1 -1
- flyte/types/_type_engine.py +1 -30
- {flyte-2.0.0b18.data → flyte-2.0.0b20.data}/scripts/runtime.py +2 -1
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/METADATA +2 -1
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/RECORD +33 -31
- {flyte-2.0.0b18.data → flyte-2.0.0b20.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b18.dist-info → flyte-2.0.0b20.dist-info}/top_level.txt +0 -0
flyte/io/_dataframe/dataframe.py
CHANGED
|
@@ -1,20 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import _datetime
|
|
4
|
-
import asyncio
|
|
5
4
|
import collections
|
|
6
5
|
import types
|
|
7
6
|
import typing
|
|
8
7
|
from abc import ABC, abstractmethod
|
|
9
|
-
from dataclasses import
|
|
8
|
+
from dataclasses import is_dataclass
|
|
10
9
|
from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union
|
|
11
10
|
|
|
12
|
-
import msgpack
|
|
13
11
|
from flyteidl.core import literals_pb2, types_pb2
|
|
14
12
|
from fsspec.utils import get_protocol
|
|
15
|
-
from mashumaro.mixins.json import DataClassJSONMixin
|
|
16
13
|
from mashumaro.types import SerializableType
|
|
17
|
-
from pydantic import model_serializer, model_validator
|
|
14
|
+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
|
|
18
15
|
from typing_extensions import Annotated, TypeAlias, get_args, get_origin
|
|
19
16
|
|
|
20
17
|
import flyte.storage as storage
|
|
@@ -48,15 +45,23 @@ GENERIC_FORMAT: DataFrameFormat = ""
|
|
|
48
45
|
GENERIC_PROTOCOL: str = "generic protocol"
|
|
49
46
|
|
|
50
47
|
|
|
51
|
-
|
|
52
|
-
class DataFrame(SerializableType, DataClassJSONMixin):
|
|
48
|
+
class DataFrame(BaseModel, SerializableType):
|
|
53
49
|
"""
|
|
54
50
|
This is the user facing DataFrame class. Please don't confuse it with the literals.StructuredDataset
|
|
55
51
|
class (that is just a model, a Python class representation of the protobuf).
|
|
56
52
|
"""
|
|
57
53
|
|
|
58
|
-
uri: typing.Optional[str] =
|
|
59
|
-
|
|
54
|
+
uri: typing.Optional[str] = Field(default=None)
|
|
55
|
+
format: typing.Optional[str] = Field(default=GENERIC_FORMAT)
|
|
56
|
+
|
|
57
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
58
|
+
|
|
59
|
+
# Private attributes that are not part of the Pydantic model schema
|
|
60
|
+
_raw_df: typing.Optional[typing.Any] = PrivateAttr(default=None)
|
|
61
|
+
_metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = PrivateAttr(default=None)
|
|
62
|
+
_literal_sd: Optional[literals_pb2.StructuredDataset] = PrivateAttr(default=None)
|
|
63
|
+
_dataframe_type: Optional[Type[Any]] = PrivateAttr(default=None)
|
|
64
|
+
_already_uploaded: bool = PrivateAttr(default=False)
|
|
60
65
|
|
|
61
66
|
# loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
|
|
62
67
|
def _serialize(self) -> Dict[str, Optional[str]]:
|
|
@@ -65,16 +70,16 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
65
70
|
engine = DataFrameTransformerEngine()
|
|
66
71
|
lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
|
|
67
72
|
sd = DataFrame(uri=lv.scalar.structured_dataset.uri)
|
|
68
|
-
sd.
|
|
73
|
+
sd.format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
|
|
69
74
|
return {
|
|
70
75
|
"uri": sd.uri,
|
|
71
|
-
"
|
|
76
|
+
"format": sd.format,
|
|
72
77
|
}
|
|
73
78
|
|
|
74
79
|
@classmethod
|
|
75
|
-
def _deserialize(cls, value) ->
|
|
80
|
+
def _deserialize(cls, value) -> DataFrame:
|
|
76
81
|
uri = value.get("uri", None)
|
|
77
|
-
|
|
82
|
+
format_val = value.get("format", None)
|
|
78
83
|
|
|
79
84
|
if uri is None:
|
|
80
85
|
raise ValueError("DataFrame's uri and file format should not be None")
|
|
@@ -86,7 +91,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
86
91
|
scalar=literals_pb2.Scalar(
|
|
87
92
|
structured_dataset=literals_pb2.StructuredDataset(
|
|
88
93
|
metadata=literals_pb2.StructuredDatasetMetadata(
|
|
89
|
-
structured_dataset_type=types_pb2.StructuredDatasetType(format=
|
|
94
|
+
structured_dataset_type=types_pb2.StructuredDatasetType(format=format_val)
|
|
90
95
|
),
|
|
91
96
|
uri=uri,
|
|
92
97
|
)
|
|
@@ -102,7 +107,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
102
107
|
lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
|
|
103
108
|
return {
|
|
104
109
|
"uri": lv.scalar.structured_dataset.uri,
|
|
105
|
-
"
|
|
110
|
+
"format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
|
|
106
111
|
}
|
|
107
112
|
|
|
108
113
|
@model_validator(mode="after")
|
|
@@ -117,7 +122,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
117
122
|
scalar=literals_pb2.Scalar(
|
|
118
123
|
structured_dataset=literals_pb2.StructuredDataset(
|
|
119
124
|
metadata=literals_pb2.StructuredDatasetMetadata(
|
|
120
|
-
structured_dataset_type=types_pb2.StructuredDatasetType(format=self.
|
|
125
|
+
structured_dataset_type=types_pb2.StructuredDatasetType(format=self.format)
|
|
121
126
|
),
|
|
122
127
|
uri=self.uri,
|
|
123
128
|
)
|
|
@@ -134,30 +139,46 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
134
139
|
def column_names(cls) -> typing.List[str]:
|
|
135
140
|
return [k for k, v in cls.columns().items()]
|
|
136
141
|
|
|
137
|
-
|
|
138
|
-
|
|
142
|
+
@classmethod
|
|
143
|
+
def from_df(
|
|
144
|
+
cls,
|
|
139
145
|
val: typing.Optional[typing.Any] = None,
|
|
140
146
|
uri: typing.Optional[str] = None,
|
|
141
|
-
|
|
147
|
+
) -> DataFrame:
|
|
148
|
+
"""
|
|
149
|
+
Wrapper to create a DataFrame from a dataframe.
|
|
150
|
+
The reason this is implemented as a wrapper instead of a full translation invoking
|
|
151
|
+
the type engine and the encoders is because there's too much information in the type
|
|
152
|
+
signature of the task that we don't want the user to have to replicate.
|
|
153
|
+
"""
|
|
154
|
+
instance = cls(uri=uri)
|
|
155
|
+
instance._raw_df = val
|
|
156
|
+
return instance
|
|
157
|
+
|
|
158
|
+
@classmethod
|
|
159
|
+
def from_existing_remote(
|
|
160
|
+
cls,
|
|
161
|
+
remote_path: str,
|
|
162
|
+
format: typing.Optional[str] = None,
|
|
142
163
|
**kwargs,
|
|
143
|
-
):
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
164
|
+
) -> "DataFrame":
|
|
165
|
+
"""
|
|
166
|
+
Create a DataFrame reference from an existing remote dataframe.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
remote_path: The remote path to the existing dataframe
|
|
170
|
+
format: Format of the stored dataframe
|
|
171
|
+
|
|
172
|
+
Example:
|
|
173
|
+
```python
|
|
174
|
+
df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
|
|
175
|
+
```
|
|
176
|
+
"""
|
|
177
|
+
return cls(uri=remote_path, format=format or GENERIC_FORMAT, **kwargs)
|
|
157
178
|
|
|
158
179
|
@property
|
|
159
180
|
def val(self) -> Optional[DF]:
|
|
160
|
-
return self.
|
|
181
|
+
return self._raw_df
|
|
161
182
|
|
|
162
183
|
@property
|
|
163
184
|
def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
|
|
@@ -201,7 +222,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
|
|
|
201
222
|
|
|
202
223
|
@task
|
|
203
224
|
def return_df() -> DataFrame:
|
|
204
|
-
df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet",
|
|
225
|
+
df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", format="parquet")
|
|
205
226
|
df = df.open(pd.DataFrame).all()
|
|
206
227
|
return df
|
|
207
228
|
|
|
@@ -244,6 +265,9 @@ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
|
|
|
244
265
|
fields = getattr(value, "__dataclass_fields__")
|
|
245
266
|
d = {k: v.type for k, v in fields.items()}
|
|
246
267
|
result.update(flatten_dict(sub_dict=d, parent_key=current_key))
|
|
268
|
+
elif hasattr(value, "model_fields"): # Pydantic model
|
|
269
|
+
d = {k: v.annotation for k, v in value.model_fields.items()}
|
|
270
|
+
result.update(flatten_dict(sub_dict=d, parent_key=current_key))
|
|
247
271
|
else:
|
|
248
272
|
result[current_key] = value
|
|
249
273
|
return result
|
|
@@ -708,16 +732,16 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
708
732
|
# return DataFrame(uri=uri)
|
|
709
733
|
if python_val.val is None:
|
|
710
734
|
uri = python_val.uri
|
|
711
|
-
|
|
735
|
+
format_val = python_val.format
|
|
712
736
|
|
|
713
737
|
# Check the user-specified uri
|
|
714
738
|
if not uri:
|
|
715
739
|
raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
|
|
716
740
|
if not storage.is_remote(uri):
|
|
717
|
-
uri = await storage.put(uri)
|
|
741
|
+
uri = await storage.put(uri, recursive=True)
|
|
718
742
|
|
|
719
|
-
# Check the user-specified
|
|
720
|
-
# When users specify
|
|
743
|
+
# Check the user-specified format
|
|
744
|
+
# When users specify format for a DataFrame, the format should be retained
|
|
721
745
|
# conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
|
|
722
746
|
# Following illustrates why we can't always copy the user-specified file_format over:
|
|
723
747
|
#
|
|
@@ -725,14 +749,14 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
725
749
|
# def modify_format(df: Annotated[DataFrame, {}, "task-format"]) -> DataFrame:
|
|
726
750
|
# return df
|
|
727
751
|
#
|
|
728
|
-
# df = DataFrame(uri="s3://my-s3-bucket/df.parquet",
|
|
752
|
+
# df = DataFrame(uri="s3://my-s3-bucket/df.parquet", format="user-format")
|
|
729
753
|
# df2 = modify_format(df=df)
|
|
730
754
|
#
|
|
731
|
-
# In this case, we expect the df2.
|
|
732
|
-
# not user-format. If we directly copy the user-specified
|
|
755
|
+
# In this case, we expect the df2.format to be task-format (as shown in Annotated),
|
|
756
|
+
# not user-format. If we directly copy the user-specified format over,
|
|
733
757
|
# the type hint information will be missing.
|
|
734
|
-
if sdt.format == GENERIC_FORMAT and
|
|
735
|
-
sdt.format =
|
|
758
|
+
if sdt.format == GENERIC_FORMAT and format_val != GENERIC_FORMAT:
|
|
759
|
+
sdt.format = format_val
|
|
736
760
|
|
|
737
761
|
sd_model = literals_pb2.StructuredDataset(
|
|
738
762
|
uri=uri,
|
|
@@ -760,8 +784,9 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
760
784
|
structured_dataset_type=expected.structured_dataset_type if expected else None
|
|
761
785
|
)
|
|
762
786
|
|
|
763
|
-
|
|
764
|
-
|
|
787
|
+
fdf = DataFrame.from_df(val=python_val)
|
|
788
|
+
fdf._metadata = meta
|
|
789
|
+
return await self.encode(fdf, python_type, protocol, fmt, sdt)
|
|
765
790
|
|
|
766
791
|
def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
|
|
767
792
|
"""
|
|
@@ -782,7 +807,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
782
807
|
|
|
783
808
|
async def encode(
|
|
784
809
|
self,
|
|
785
|
-
|
|
810
|
+
df: DataFrame,
|
|
786
811
|
df_type: Type,
|
|
787
812
|
protocol: str,
|
|
788
813
|
format: str,
|
|
@@ -791,7 +816,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
791
816
|
handler: DataFrameEncoder
|
|
792
817
|
handler = self.get_encoder(df_type, protocol, format)
|
|
793
818
|
|
|
794
|
-
sd_model = await handler.encode(
|
|
819
|
+
sd_model = await handler.encode(df, structured_literal_type)
|
|
795
820
|
# This block is here in case the encoder did not set the type information in the metadata. Since this literal
|
|
796
821
|
# is special in that it carries around the type itself, we want to make sure the type info therein is at
|
|
797
822
|
# least as good as the type of the interface.
|
|
@@ -807,72 +832,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
807
832
|
lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
|
|
808
833
|
|
|
809
834
|
# Because the handler.encode may have uploaded something, and because the sd may end up living inside a
|
|
810
|
-
# dataclass, we need to modify any uploaded flyte:// urls here.
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
835
|
+
# dataclass, we need to modify any uploaded flyte:// urls here. Needed here even though the Type engine
|
|
836
|
+
# already does this because the DataframeTransformerEngine may be called directly.
|
|
837
|
+
modify_literal_uris(lit)
|
|
838
|
+
df._literal_sd = sd_model
|
|
839
|
+
df._already_uploaded = True
|
|
814
840
|
return lit
|
|
815
841
|
|
|
816
|
-
# pr: han-ru: can this be removed if we make DataFrame a pydantic model?
|
|
817
|
-
def dict_to_dataframe(
|
|
818
|
-
self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | DataFrame
|
|
819
|
-
) -> T | DataFrame:
|
|
820
|
-
uri = dict_obj.get("uri", None)
|
|
821
|
-
file_format = dict_obj.get("file_format", None)
|
|
822
|
-
|
|
823
|
-
if uri is None:
|
|
824
|
-
raise ValueError("DataFrame's uri and file format should not be None")
|
|
825
|
-
|
|
826
|
-
# Instead of using python native DataFrame, we need to build a literals.StructuredDataset
|
|
827
|
-
# The reason is that _literal_sd of python sd is accessed when task output LiteralMap is
|
|
828
|
-
# converted back to flyteidl. Hence, _literal_sd must have to_flyte_idl method
|
|
829
|
-
# See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
|
|
830
|
-
# For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
|
|
831
|
-
sdt = types_pb2.StructuredDatasetType(format=file_format)
|
|
832
|
-
metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
|
|
833
|
-
sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
|
|
834
|
-
|
|
835
|
-
return asyncio.run(
|
|
836
|
-
DataFrameTransformerEngine().to_python_value(
|
|
837
|
-
literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
|
|
838
|
-
expected_python_type,
|
|
839
|
-
)
|
|
840
|
-
)
|
|
841
|
-
|
|
842
|
-
def from_binary_idl(
|
|
843
|
-
self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | DataFrame
|
|
844
|
-
) -> T | DataFrame:
|
|
845
|
-
"""
|
|
846
|
-
If the input is from flytekit, the Life Cycle will be as follows:
|
|
847
|
-
|
|
848
|
-
Life Cycle:
|
|
849
|
-
binary IDL -> resolved binary -> bytes -> expected Python object
|
|
850
|
-
(flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized
|
|
851
|
-
serialization) deserialization)
|
|
852
|
-
|
|
853
|
-
Example Code:
|
|
854
|
-
@dataclass
|
|
855
|
-
class DC:
|
|
856
|
-
sd: StructuredDataset
|
|
857
|
-
|
|
858
|
-
@workflow
|
|
859
|
-
def wf(dc: DC):
|
|
860
|
-
t_sd(dc.sd)
|
|
861
|
-
|
|
862
|
-
Note:
|
|
863
|
-
- The deserialization is the same as put a structured dataset in a dataclass,
|
|
864
|
-
which will deserialize by the mashumaro's API.
|
|
865
|
-
|
|
866
|
-
Related PR:
|
|
867
|
-
- Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
|
|
868
|
-
- Link: https://github.com/flyteorg/flytekit/pull/2554
|
|
869
|
-
"""
|
|
870
|
-
if binary_idl_object.tag == MESSAGEPACK:
|
|
871
|
-
python_val = msgpack.loads(binary_idl_object.value)
|
|
872
|
-
return self.dict_to_dataframe(dict_obj=python_val, expected_python_type=expected_python_type)
|
|
873
|
-
else:
|
|
874
|
-
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
|
|
875
|
-
|
|
876
842
|
async def to_python_value(
|
|
877
843
|
self, lv: literals_pb2.Literal, expected_python_type: Type[T] | DataFrame
|
|
878
844
|
) -> T | DataFrame:
|
|
@@ -906,9 +872,8 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
906
872
|
| | the running task's signature. | |
|
|
907
873
|
+-----------------------------+-----------------------------------------+--------------------------------------+
|
|
908
874
|
"""
|
|
909
|
-
# Handle dataclass attribute access
|
|
910
875
|
if lv.HasField("scalar") and lv.scalar.HasField("binary"):
|
|
911
|
-
|
|
876
|
+
raise TypeTransformerFailedError("Attribute access unsupported.")
|
|
912
877
|
|
|
913
878
|
# Detect annotations and extract out all the relevant information that the user might supply
|
|
914
879
|
expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
|
|
@@ -939,16 +904,12 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
|
|
|
939
904
|
# t1(input_a: DataFrame) # or
|
|
940
905
|
# t1(input_a: Annotated[DataFrame, my_cols])
|
|
941
906
|
if issubclass(expected_python_type, DataFrame):
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
)
|
|
947
|
-
sd._literal_sd = lv.scalar.structured_dataset
|
|
948
|
-
sd.file_format = metad.structured_dataset_type.format
|
|
949
|
-
return sd
|
|
907
|
+
fdf = DataFrame(format=metad.structured_dataset_type.format)
|
|
908
|
+
fdf._literal_sd = lv.scalar.structured_dataset
|
|
909
|
+
fdf._metadata = metad
|
|
910
|
+
return fdf
|
|
950
911
|
|
|
951
|
-
# If the requested type was not a
|
|
912
|
+
# If the requested type was not a flyte.DataFrame, then it means it was a raw dataframe type, which means
|
|
952
913
|
# we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
|
|
953
914
|
return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
|
|
954
915
|
|
flyte/remote/_task.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import functools
|
|
4
5
|
from dataclasses import dataclass
|
|
5
|
-
from threading import Lock
|
|
6
6
|
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
|
|
7
7
|
|
|
8
8
|
import rich.repr
|
|
@@ -49,7 +49,7 @@ class LazyEntity:
|
|
|
49
49
|
self._task: Optional[TaskDetails] = None
|
|
50
50
|
self._getter = getter
|
|
51
51
|
self._name = name
|
|
52
|
-
self._mutex = Lock()
|
|
52
|
+
self._mutex = asyncio.Lock()
|
|
53
53
|
|
|
54
54
|
@property
|
|
55
55
|
def name(self) -> str:
|
|
@@ -60,11 +60,11 @@ class LazyEntity:
|
|
|
60
60
|
"""
|
|
61
61
|
Forwards all other attributes to task, causing the task to be fetched!
|
|
62
62
|
"""
|
|
63
|
-
with self._mutex:
|
|
63
|
+
async with self._mutex:
|
|
64
64
|
if self._task is None:
|
|
65
65
|
self._task = await self._getter()
|
|
66
|
-
|
|
67
|
-
|
|
66
|
+
if self._task is None:
|
|
67
|
+
raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
|
|
68
68
|
return self._task
|
|
69
69
|
|
|
70
70
|
@syncify
|
|
@@ -73,8 +73,10 @@ class LazyEntity:
|
|
|
73
73
|
**kwargs: Any,
|
|
74
74
|
) -> LazyEntity:
|
|
75
75
|
task_details = cast(TaskDetails, await self.fetch.aio())
|
|
76
|
-
task_details.override(**kwargs)
|
|
77
|
-
|
|
76
|
+
new_task_details = task_details.override(**kwargs)
|
|
77
|
+
new_entity = LazyEntity(self._name, self._getter)
|
|
78
|
+
new_entity._task = new_task_details
|
|
79
|
+
return new_entity
|
|
78
80
|
|
|
79
81
|
async def __call__(self, *args, **kwargs):
|
|
80
82
|
"""
|
|
@@ -93,7 +95,7 @@ class LazyEntity:
|
|
|
93
95
|
AutoVersioning = Literal["latest", "current"]
|
|
94
96
|
|
|
95
97
|
|
|
96
|
-
@dataclass
|
|
98
|
+
@dataclass(frozen=True)
|
|
97
99
|
class TaskDetails(ToJSONMixin):
|
|
98
100
|
pb2: task_definition_pb2.TaskDetails
|
|
99
101
|
max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
|
|
@@ -261,12 +263,6 @@ class TaskDetails(ToJSONMixin):
|
|
|
261
263
|
f"Reference task {self.name} does not support positional arguments"
|
|
262
264
|
f"currently. Please use keyword arguments."
|
|
263
265
|
)
|
|
264
|
-
if len(self.required_args) > 0:
|
|
265
|
-
if len(args) + len(kwargs) < len(self.required_args):
|
|
266
|
-
raise ValueError(
|
|
267
|
-
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
268
|
-
f"but only received args:{args} kwargs{kwargs}."
|
|
269
|
-
)
|
|
270
266
|
|
|
271
267
|
ctx = internal_ctx()
|
|
272
268
|
if ctx.is_task_context():
|
|
@@ -276,9 +272,17 @@ class TaskDetails(ToJSONMixin):
|
|
|
276
272
|
from flyte._internal.controllers import get_controller
|
|
277
273
|
|
|
278
274
|
controller = get_controller()
|
|
275
|
+
if len(self.required_args) > 0:
|
|
276
|
+
if len(args) + len(kwargs) < len(self.required_args):
|
|
277
|
+
raise ValueError(
|
|
278
|
+
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
279
|
+
f"but only received args:{args} kwargs{kwargs}."
|
|
280
|
+
)
|
|
279
281
|
if controller:
|
|
280
|
-
return await controller.submit_task_ref(self
|
|
281
|
-
raise flyte.errors
|
|
282
|
+
return await controller.submit_task_ref(self, *args, **kwargs)
|
|
283
|
+
raise flyte.errors.ReferenceTaskError(
|
|
284
|
+
f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
|
|
285
|
+
)
|
|
282
286
|
|
|
283
287
|
def override(
|
|
284
288
|
self,
|
|
@@ -289,6 +293,8 @@ class TaskDetails(ToJSONMixin):
|
|
|
289
293
|
timeout: Optional[flyte.TimeoutType] = None,
|
|
290
294
|
env_vars: Optional[Dict[str, str]] = None,
|
|
291
295
|
secrets: Optional[flyte.SecretRequest] = None,
|
|
296
|
+
max_inline_io_bytes: Optional[int] = None,
|
|
297
|
+
cache: Optional[flyte.Cache] = None,
|
|
292
298
|
**kwargs: Any,
|
|
293
299
|
) -> TaskDetails:
|
|
294
300
|
if len(kwargs) > 0:
|
|
@@ -296,23 +302,47 @@ class TaskDetails(ToJSONMixin):
|
|
|
296
302
|
f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
|
|
297
303
|
f"Check the parameters for override method."
|
|
298
304
|
)
|
|
299
|
-
|
|
305
|
+
pb2 = task_definition_pb2.TaskDetails()
|
|
306
|
+
pb2.CopyFrom(self.pb2)
|
|
307
|
+
|
|
300
308
|
if short_name:
|
|
301
|
-
|
|
309
|
+
pb2.metadata.short_name = short_name
|
|
310
|
+
|
|
311
|
+
template = pb2.spec.task_template
|
|
302
312
|
if secrets:
|
|
303
313
|
template.security_context.CopyFrom(get_security_context(secrets))
|
|
314
|
+
|
|
304
315
|
if template.HasField("container"):
|
|
305
316
|
if env_vars:
|
|
306
317
|
template.container.env.clear()
|
|
307
318
|
template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
|
|
308
319
|
if resources:
|
|
309
320
|
template.container.resources.CopyFrom(get_proto_resources(resources))
|
|
321
|
+
|
|
322
|
+
md = template.metadata
|
|
310
323
|
if retries:
|
|
311
|
-
|
|
312
|
-
if timeout:
|
|
313
|
-
template.metadata.timeout.CopyFrom(get_proto_timeout(timeout))
|
|
324
|
+
md.retries.CopyFrom(get_proto_retry_strategy(retries))
|
|
314
325
|
|
|
315
|
-
|
|
326
|
+
if timeout:
|
|
327
|
+
md.timeout.CopyFrom(get_proto_timeout(timeout))
|
|
328
|
+
|
|
329
|
+
if cache:
|
|
330
|
+
if cache.behavior == "disable":
|
|
331
|
+
md.discoverable = False
|
|
332
|
+
md.discovery_version = ""
|
|
333
|
+
elif cache.behavior == "override":
|
|
334
|
+
md.discoverable = True
|
|
335
|
+
if not cache.version_override:
|
|
336
|
+
raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
|
|
337
|
+
md.discovery_version = cache.version_override
|
|
338
|
+
else:
|
|
339
|
+
if cache.behavior == "auto":
|
|
340
|
+
raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
|
|
341
|
+
raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
|
|
342
|
+
md.cache_serializable = cache.serialize
|
|
343
|
+
md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
|
|
344
|
+
|
|
345
|
+
return TaskDetails(pb2, max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes)
|
|
316
346
|
|
|
317
347
|
def __rich_repr__(self) -> rich.repr.Result:
|
|
318
348
|
"""
|
flyte/report/_report.py
CHANGED
|
@@ -4,7 +4,6 @@ import string
|
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
5
|
from typing import TYPE_CHECKING, Dict, List, Union
|
|
6
6
|
|
|
7
|
-
from flyte._internal.runtime import io
|
|
8
7
|
from flyte._logging import logger
|
|
9
8
|
from flyte._tools import ipython_check
|
|
10
9
|
from flyte.syncify import syncify
|
|
@@ -133,6 +132,7 @@ async def flush():
|
|
|
133
132
|
"""
|
|
134
133
|
import flyte.storage as storage
|
|
135
134
|
from flyte._context import internal_ctx
|
|
135
|
+
from flyte._internal.runtime import io
|
|
136
136
|
|
|
137
137
|
if not internal_ctx().is_task_context():
|
|
138
138
|
return
|
flyte/types/_type_engine.py
CHANGED
|
@@ -455,35 +455,6 @@ class DataclassTransformer(TypeTransformer[object]):
|
|
|
455
455
|
2. Deserialization: The dataclass transformer converts the MessagePack Bytes back to a dataclass.
|
|
456
456
|
(1) Convert MessagePack Bytes to a dataclass using mashumaro.
|
|
457
457
|
(2) Handle dataclass attributes to ensure they are of the correct types.
|
|
458
|
-
|
|
459
|
-
TODO: Update the example using mashumaro instead of the older library
|
|
460
|
-
|
|
461
|
-
Example
|
|
462
|
-
|
|
463
|
-
.. code-block:: python
|
|
464
|
-
|
|
465
|
-
@dataclass
|
|
466
|
-
class Test:
|
|
467
|
-
a: int
|
|
468
|
-
b: str
|
|
469
|
-
|
|
470
|
-
t = Test(a=10,b="e")
|
|
471
|
-
JSONSchema().dump(t.schema())
|
|
472
|
-
|
|
473
|
-
Output will look like
|
|
474
|
-
|
|
475
|
-
.. code-block:: json
|
|
476
|
-
|
|
477
|
-
{'$schema': 'http://json-schema.org/draft-07/schema#',
|
|
478
|
-
'definitions': {'TestSchema': {'properties': {'a': {'title': 'a',
|
|
479
|
-
'type': 'number',
|
|
480
|
-
'format': 'integer'},
|
|
481
|
-
'b': {'title': 'b', 'type': 'string'}},
|
|
482
|
-
'type': 'object',
|
|
483
|
-
'additionalProperties': False}},
|
|
484
|
-
'$ref': '#/definitions/TestSchema'}
|
|
485
|
-
|
|
486
|
-
|
|
487
458
|
"""
|
|
488
459
|
|
|
489
460
|
def __init__(self) -> None:
|
|
@@ -615,7 +586,7 @@ class DataclassTransformer(TypeTransformer[object]):
|
|
|
615
586
|
}
|
|
616
587
|
)
|
|
617
588
|
|
|
618
|
-
# The type engine used to publish
|
|
589
|
+
# The type engine used to publish the type `structure` for attribute access. As of v2, this is no longer needed.
|
|
619
590
|
return types_pb2.LiteralType(
|
|
620
591
|
simple=types_pb2.SimpleType.STRUCT,
|
|
621
592
|
metadata=schema,
|
|
@@ -101,7 +101,6 @@ def main(
|
|
|
101
101
|
from flyte._logging import logger
|
|
102
102
|
from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
|
|
103
103
|
|
|
104
|
-
logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
|
|
105
104
|
logger.info("Registering faulthandler for SIGUSR1")
|
|
106
105
|
faulthandler.register(signal.SIGUSR1)
|
|
107
106
|
|
|
@@ -117,6 +116,8 @@ def main(
|
|
|
117
116
|
if name.startswith("{{"):
|
|
118
117
|
name = os.getenv("ACTION_NAME", "")
|
|
119
118
|
|
|
119
|
+
logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
|
|
120
|
+
|
|
120
121
|
if debug and name == "a0":
|
|
121
122
|
from flyte._debug.vscode import _start_vscode_server
|
|
122
123
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flyte
|
|
3
|
-
Version: 2.0.
|
|
3
|
+
Version: 2.0.0b20
|
|
4
4
|
Summary: Add your description here
|
|
5
5
|
Author-email: Ketan Umare <kumare3@users.noreply.github.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -24,6 +24,7 @@ Requires-Dist: toml>=0.10.2
|
|
|
24
24
|
Requires-Dist: async-lru>=2.0.5
|
|
25
25
|
Requires-Dist: mashumaro
|
|
26
26
|
Requires-Dist: dataclasses_json
|
|
27
|
+
Requires-Dist: aiolimiter>=1.2.1
|
|
27
28
|
Dynamic: license-file
|
|
28
29
|
|
|
29
30
|
# Flyte 2 SDK 🚀
|