flyte 2.0.0b18__py3-none-any.whl → 2.0.0b20__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

@@ -1,20 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import _datetime
4
- import asyncio
5
4
  import collections
6
5
  import types
7
6
  import typing
8
7
  from abc import ABC, abstractmethod
9
- from dataclasses import dataclass, field, is_dataclass
8
+ from dataclasses import is_dataclass
10
9
  from typing import Any, ClassVar, Coroutine, Dict, Generic, List, Optional, Type, Union
11
10
 
12
- import msgpack
13
11
  from flyteidl.core import literals_pb2, types_pb2
14
12
  from fsspec.utils import get_protocol
15
- from mashumaro.mixins.json import DataClassJSONMixin
16
13
  from mashumaro.types import SerializableType
17
- from pydantic import model_serializer, model_validator
14
+ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_serializer, model_validator
18
15
  from typing_extensions import Annotated, TypeAlias, get_args, get_origin
19
16
 
20
17
  import flyte.storage as storage
@@ -48,15 +45,23 @@ GENERIC_FORMAT: DataFrameFormat = ""
48
45
  GENERIC_PROTOCOL: str = "generic protocol"
49
46
 
50
47
 
51
- @dataclass
52
- class DataFrame(SerializableType, DataClassJSONMixin):
48
+ class DataFrame(BaseModel, SerializableType):
53
49
  """
54
50
  This is the user facing DataFrame class. Please don't confuse it with the literals.StructuredDataset
55
51
  class (that is just a model, a Python class representation of the protobuf).
56
52
  """
57
53
 
58
- uri: typing.Optional[str] = field(default=None)
59
- file_format: typing.Optional[str] = field(default=GENERIC_FORMAT)
54
+ uri: typing.Optional[str] = Field(default=None)
55
+ format: typing.Optional[str] = Field(default=GENERIC_FORMAT)
56
+
57
+ model_config = ConfigDict(arbitrary_types_allowed=True)
58
+
59
+ # Private attributes that are not part of the Pydantic model schema
60
+ _raw_df: typing.Optional[typing.Any] = PrivateAttr(default=None)
61
+ _metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = PrivateAttr(default=None)
62
+ _literal_sd: Optional[literals_pb2.StructuredDataset] = PrivateAttr(default=None)
63
+ _dataframe_type: Optional[Type[Any]] = PrivateAttr(default=None)
64
+ _already_uploaded: bool = PrivateAttr(default=False)
60
65
 
61
66
  # loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
62
67
  def _serialize(self) -> Dict[str, Optional[str]]:
@@ -65,16 +70,16 @@ class DataFrame(SerializableType, DataClassJSONMixin):
65
70
  engine = DataFrameTransformerEngine()
66
71
  lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
67
72
  sd = DataFrame(uri=lv.scalar.structured_dataset.uri)
68
- sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
73
+ sd.format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
69
74
  return {
70
75
  "uri": sd.uri,
71
- "file_format": sd.file_format,
76
+ "format": sd.format,
72
77
  }
73
78
 
74
79
  @classmethod
75
- def _deserialize(cls, value) -> "DataFrame":
80
+ def _deserialize(cls, value) -> DataFrame:
76
81
  uri = value.get("uri", None)
77
- file_format = value.get("file_format", None)
82
+ format_val = value.get("format", None)
78
83
 
79
84
  if uri is None:
80
85
  raise ValueError("DataFrame's uri and file format should not be None")
@@ -86,7 +91,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
86
91
  scalar=literals_pb2.Scalar(
87
92
  structured_dataset=literals_pb2.StructuredDataset(
88
93
  metadata=literals_pb2.StructuredDatasetMetadata(
89
- structured_dataset_type=types_pb2.StructuredDatasetType(format=file_format)
94
+ structured_dataset_type=types_pb2.StructuredDatasetType(format=format_val)
90
95
  ),
91
96
  uri=uri,
92
97
  )
@@ -102,7 +107,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
102
107
  lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
103
108
  return {
104
109
  "uri": lv.scalar.structured_dataset.uri,
105
- "file_format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
110
+ "format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
106
111
  }
107
112
 
108
113
  @model_validator(mode="after")
@@ -117,7 +122,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
117
122
  scalar=literals_pb2.Scalar(
118
123
  structured_dataset=literals_pb2.StructuredDataset(
119
124
  metadata=literals_pb2.StructuredDatasetMetadata(
120
- structured_dataset_type=types_pb2.StructuredDatasetType(format=self.file_format)
125
+ structured_dataset_type=types_pb2.StructuredDatasetType(format=self.format)
121
126
  ),
122
127
  uri=self.uri,
123
128
  )
@@ -134,30 +139,46 @@ class DataFrame(SerializableType, DataClassJSONMixin):
134
139
  def column_names(cls) -> typing.List[str]:
135
140
  return [k for k, v in cls.columns().items()]
136
141
 
137
- def __init__(
138
- self,
142
+ @classmethod
143
+ def from_df(
144
+ cls,
139
145
  val: typing.Optional[typing.Any] = None,
140
146
  uri: typing.Optional[str] = None,
141
- metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = None,
147
+ ) -> DataFrame:
148
+ """
149
+ Wrapper to create a DataFrame from a dataframe.
150
+ The reason this is implemented as a wrapper instead of a full translation invoking
151
+ the type engine and the encoders is because there's too much information in the type
152
+ signature of the task that we don't want the user to have to replicate.
153
+ """
154
+ instance = cls(uri=uri)
155
+ instance._raw_df = val
156
+ return instance
157
+
158
+ @classmethod
159
+ def from_existing_remote(
160
+ cls,
161
+ remote_path: str,
162
+ format: typing.Optional[str] = None,
142
163
  **kwargs,
143
- ):
144
- self._val = val
145
- # Make these fields public, so that the dataclass transformer can set a value for it
146
- # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
147
- self.uri = uri
148
- # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
149
- self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
150
- # This is a special attribute that indicates if the data was either downloaded or uploaded
151
- self._metadata = metadata
152
- # This is not for users to set, the transformer will set this.
153
- self._literal_sd: Optional[literals_pb2.StructuredDataset] = None
154
- # Not meant for users to set, will be set by an open() call
155
- self._dataframe_type: Optional[DF] = None # type: ignore
156
- self._already_uploaded = False
164
+ ) -> "DataFrame":
165
+ """
166
+ Create a DataFrame reference from an existing remote dataframe.
167
+
168
+ Args:
169
+ remote_path: The remote path to the existing dataframe
170
+ format: Format of the stored dataframe
171
+
172
+ Example:
173
+ ```python
174
+ df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
175
+ ```
176
+ """
177
+ return cls(uri=remote_path, format=format or GENERIC_FORMAT, **kwargs)
157
178
 
158
179
  @property
159
180
  def val(self) -> Optional[DF]:
160
- return self._val
181
+ return self._raw_df
161
182
 
162
183
  @property
163
184
  def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
@@ -201,7 +222,7 @@ class DataFrame(SerializableType, DataClassJSONMixin):
201
222
 
202
223
  @task
203
224
  def return_df() -> DataFrame:
204
- df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
225
+ df = DataFrame(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", format="parquet")
205
226
  df = df.open(pd.DataFrame).all()
206
227
  return df
207
228
 
@@ -244,6 +265,9 @@ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
244
265
  fields = getattr(value, "__dataclass_fields__")
245
266
  d = {k: v.type for k, v in fields.items()}
246
267
  result.update(flatten_dict(sub_dict=d, parent_key=current_key))
268
+ elif hasattr(value, "model_fields"): # Pydantic model
269
+ d = {k: v.annotation for k, v in value.model_fields.items()}
270
+ result.update(flatten_dict(sub_dict=d, parent_key=current_key))
247
271
  else:
248
272
  result[current_key] = value
249
273
  return result
@@ -708,16 +732,16 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
708
732
  # return DataFrame(uri=uri)
709
733
  if python_val.val is None:
710
734
  uri = python_val.uri
711
- file_format = python_val.file_format
735
+ format_val = python_val.format
712
736
 
713
737
  # Check the user-specified uri
714
738
  if not uri:
715
739
  raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
716
740
  if not storage.is_remote(uri):
717
- uri = await storage.put(uri)
741
+ uri = await storage.put(uri, recursive=True)
718
742
 
719
- # Check the user-specified file_format
720
- # When users specify file_format for a DataFrame, the file_format should be retained
743
+ # Check the user-specified format
744
+ # When users specify format for a DataFrame, the format should be retained
721
745
  # conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
722
746
  # Following illustrates why we can't always copy the user-specified file_format over:
723
747
  #
@@ -725,14 +749,14 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
725
749
  # def modify_format(df: Annotated[DataFrame, {}, "task-format"]) -> DataFrame:
726
750
  # return df
727
751
  #
728
- # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
752
+ # df = DataFrame(uri="s3://my-s3-bucket/df.parquet", format="user-format")
729
753
  # df2 = modify_format(df=df)
730
754
  #
731
- # In this case, we expect the df2.file_format to be task-format (as shown in Annotated),
732
- # not user-format. If we directly copy the user-specified file_format over,
755
+ # In this case, we expect the df2.format to be task-format (as shown in Annotated),
756
+ # not user-format. If we directly copy the user-specified format over,
733
757
  # the type hint information will be missing.
734
- if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
735
- sdt.format = file_format
758
+ if sdt.format == GENERIC_FORMAT and format_val != GENERIC_FORMAT:
759
+ sdt.format = format_val
736
760
 
737
761
  sd_model = literals_pb2.StructuredDataset(
738
762
  uri=uri,
@@ -760,8 +784,9 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
760
784
  structured_dataset_type=expected.structured_dataset_type if expected else None
761
785
  )
762
786
 
763
- sd = DataFrame(val=python_val, metadata=meta)
764
- return await self.encode(sd, python_type, protocol, fmt, sdt)
787
+ fdf = DataFrame.from_df(val=python_val)
788
+ fdf._metadata = meta
789
+ return await self.encode(fdf, python_type, protocol, fmt, sdt)
765
790
 
766
791
  def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
767
792
  """
@@ -782,7 +807,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
782
807
 
783
808
  async def encode(
784
809
  self,
785
- sd: DataFrame,
810
+ df: DataFrame,
786
811
  df_type: Type,
787
812
  protocol: str,
788
813
  format: str,
@@ -791,7 +816,7 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
791
816
  handler: DataFrameEncoder
792
817
  handler = self.get_encoder(df_type, protocol, format)
793
818
 
794
- sd_model = await handler.encode(sd, structured_literal_type)
819
+ sd_model = await handler.encode(df, structured_literal_type)
795
820
  # This block is here in case the encoder did not set the type information in the metadata. Since this literal
796
821
  # is special in that it carries around the type itself, we want to make sure the type info therein is at
797
822
  # least as good as the type of the interface.
@@ -807,72 +832,13 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
807
832
  lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
808
833
 
809
834
  # Because the handler.encode may have uploaded something, and because the sd may end up living inside a
810
- # dataclass, we need to modify any uploaded flyte:// urls here.
811
- modify_literal_uris(lit) # todo: verify that this can be removed.
812
- sd._literal_sd = sd_model
813
- sd._already_uploaded = True
835
+ # dataclass, we need to modify any uploaded flyte:// urls here. Needed here even though the Type engine
836
+ # already does this because the DataframeTransformerEngine may be called directly.
837
+ modify_literal_uris(lit)
838
+ df._literal_sd = sd_model
839
+ df._already_uploaded = True
814
840
  return lit
815
841
 
816
- # pr: han-ru: can this be removed if we make DataFrame a pydantic model?
817
- def dict_to_dataframe(
818
- self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | DataFrame
819
- ) -> T | DataFrame:
820
- uri = dict_obj.get("uri", None)
821
- file_format = dict_obj.get("file_format", None)
822
-
823
- if uri is None:
824
- raise ValueError("DataFrame's uri and file format should not be None")
825
-
826
- # Instead of using python native DataFrame, we need to build a literals.StructuredDataset
827
- # The reason is that _literal_sd of python sd is accessed when task output LiteralMap is
828
- # converted back to flyteidl. Hence, _literal_sd must have to_flyte_idl method
829
- # See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
830
- # For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
831
- sdt = types_pb2.StructuredDatasetType(format=file_format)
832
- metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
833
- sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
834
-
835
- return asyncio.run(
836
- DataFrameTransformerEngine().to_python_value(
837
- literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
838
- expected_python_type,
839
- )
840
- )
841
-
842
- def from_binary_idl(
843
- self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | DataFrame
844
- ) -> T | DataFrame:
845
- """
846
- If the input is from flytekit, the Life Cycle will be as follows:
847
-
848
- Life Cycle:
849
- binary IDL -> resolved binary -> bytes -> expected Python object
850
- (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized
851
- serialization) deserialization)
852
-
853
- Example Code:
854
- @dataclass
855
- class DC:
856
- sd: StructuredDataset
857
-
858
- @workflow
859
- def wf(dc: DC):
860
- t_sd(dc.sd)
861
-
862
- Note:
863
- - The deserialization is the same as put a structured dataset in a dataclass,
864
- which will deserialize by the mashumaro's API.
865
-
866
- Related PR:
867
- - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
868
- - Link: https://github.com/flyteorg/flytekit/pull/2554
869
- """
870
- if binary_idl_object.tag == MESSAGEPACK:
871
- python_val = msgpack.loads(binary_idl_object.value)
872
- return self.dict_to_dataframe(dict_obj=python_val, expected_python_type=expected_python_type)
873
- else:
874
- raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
875
-
876
842
  async def to_python_value(
877
843
  self, lv: literals_pb2.Literal, expected_python_type: Type[T] | DataFrame
878
844
  ) -> T | DataFrame:
@@ -906,9 +872,8 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
906
872
  | | the running task's signature. | |
907
873
  +-----------------------------+-----------------------------------------+--------------------------------------+
908
874
  """
909
- # Handle dataclass attribute access
910
875
  if lv.HasField("scalar") and lv.scalar.HasField("binary"):
911
- return self.from_binary_idl(lv.scalar.binary, expected_python_type)
876
+ raise TypeTransformerFailedError("Attribute access unsupported.")
912
877
 
913
878
  # Detect annotations and extract out all the relevant information that the user might supply
914
879
  expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
@@ -939,16 +904,12 @@ class DataFrameTransformerEngine(TypeTransformer[DataFrame]):
939
904
  # t1(input_a: DataFrame) # or
940
905
  # t1(input_a: Annotated[DataFrame, my_cols])
941
906
  if issubclass(expected_python_type, DataFrame):
942
- sd = expected_python_type(
943
- dataframe=None,
944
- # Note here that the type being passed in
945
- metadata=metad,
946
- )
947
- sd._literal_sd = lv.scalar.structured_dataset
948
- sd.file_format = metad.structured_dataset_type.format
949
- return sd
907
+ fdf = DataFrame(format=metad.structured_dataset_type.format)
908
+ fdf._literal_sd = lv.scalar.structured_dataset
909
+ fdf._metadata = metad
910
+ return fdf
950
911
 
951
- # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
912
+ # If the requested type was not a flyte.DataFrame, then it means it was a raw dataframe type, which means
952
913
  # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
953
914
  return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
954
915
 
flyte/remote/_task.py CHANGED
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import functools
4
5
  from dataclasses import dataclass
5
- from threading import Lock
6
6
  from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import rich.repr
@@ -49,7 +49,7 @@ class LazyEntity:
49
49
  self._task: Optional[TaskDetails] = None
50
50
  self._getter = getter
51
51
  self._name = name
52
- self._mutex = Lock()
52
+ self._mutex = asyncio.Lock()
53
53
 
54
54
  @property
55
55
  def name(self) -> str:
@@ -60,11 +60,11 @@ class LazyEntity:
60
60
  """
61
61
  Forwards all other attributes to task, causing the task to be fetched!
62
62
  """
63
- with self._mutex:
63
+ async with self._mutex:
64
64
  if self._task is None:
65
65
  self._task = await self._getter()
66
- if self._task is None:
67
- raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
66
+ if self._task is None:
67
+ raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
68
68
  return self._task
69
69
 
70
70
  @syncify
@@ -73,8 +73,10 @@ class LazyEntity:
73
73
  **kwargs: Any,
74
74
  ) -> LazyEntity:
75
75
  task_details = cast(TaskDetails, await self.fetch.aio())
76
- task_details.override(**kwargs)
77
- return self
76
+ new_task_details = task_details.override(**kwargs)
77
+ new_entity = LazyEntity(self._name, self._getter)
78
+ new_entity._task = new_task_details
79
+ return new_entity
78
80
 
79
81
  async def __call__(self, *args, **kwargs):
80
82
  """
@@ -93,7 +95,7 @@ class LazyEntity:
93
95
  AutoVersioning = Literal["latest", "current"]
94
96
 
95
97
 
96
- @dataclass
98
+ @dataclass(frozen=True)
97
99
  class TaskDetails(ToJSONMixin):
98
100
  pb2: task_definition_pb2.TaskDetails
99
101
  max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
@@ -261,12 +263,6 @@ class TaskDetails(ToJSONMixin):
261
263
  f"Reference task {self.name} does not support positional arguments"
262
264
  f"currently. Please use keyword arguments."
263
265
  )
264
- if len(self.required_args) > 0:
265
- if len(args) + len(kwargs) < len(self.required_args):
266
- raise ValueError(
267
- f"Task {self.name} requires at least {self.required_args} arguments, "
268
- f"but only received args:{args} kwargs{kwargs}."
269
- )
270
266
 
271
267
  ctx = internal_ctx()
272
268
  if ctx.is_task_context():
@@ -276,9 +272,17 @@ class TaskDetails(ToJSONMixin):
276
272
  from flyte._internal.controllers import get_controller
277
273
 
278
274
  controller = get_controller()
275
+ if len(self.required_args) > 0:
276
+ if len(args) + len(kwargs) < len(self.required_args):
277
+ raise ValueError(
278
+ f"Task {self.name} requires at least {self.required_args} arguments, "
279
+ f"but only received args:{args} kwargs{kwargs}."
280
+ )
279
281
  if controller:
280
- return await controller.submit_task_ref(self.pb2, self.max_inline_io_bytes, *args, **kwargs)
281
- raise flyte.errors
282
+ return await controller.submit_task_ref(self, *args, **kwargs)
283
+ raise flyte.errors.ReferenceTaskError(
284
+ f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
285
+ )
282
286
 
283
287
  def override(
284
288
  self,
@@ -289,6 +293,8 @@ class TaskDetails(ToJSONMixin):
289
293
  timeout: Optional[flyte.TimeoutType] = None,
290
294
  env_vars: Optional[Dict[str, str]] = None,
291
295
  secrets: Optional[flyte.SecretRequest] = None,
296
+ max_inline_io_bytes: Optional[int] = None,
297
+ cache: Optional[flyte.Cache] = None,
292
298
  **kwargs: Any,
293
299
  ) -> TaskDetails:
294
300
  if len(kwargs) > 0:
@@ -296,23 +302,47 @@ class TaskDetails(ToJSONMixin):
296
302
  f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
297
303
  f"Check the parameters for override method."
298
304
  )
299
- template = self.pb2.spec.task_template
305
+ pb2 = task_definition_pb2.TaskDetails()
306
+ pb2.CopyFrom(self.pb2)
307
+
300
308
  if short_name:
301
- self.pb2.metadata.short_name = short_name
309
+ pb2.metadata.short_name = short_name
310
+
311
+ template = pb2.spec.task_template
302
312
  if secrets:
303
313
  template.security_context.CopyFrom(get_security_context(secrets))
314
+
304
315
  if template.HasField("container"):
305
316
  if env_vars:
306
317
  template.container.env.clear()
307
318
  template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
308
319
  if resources:
309
320
  template.container.resources.CopyFrom(get_proto_resources(resources))
321
+
322
+ md = template.metadata
310
323
  if retries:
311
- template.metadata.retries.CopyFrom(get_proto_retry_strategy(retries))
312
- if timeout:
313
- template.metadata.timeout.CopyFrom(get_proto_timeout(timeout))
324
+ md.retries.CopyFrom(get_proto_retry_strategy(retries))
314
325
 
315
- return self
326
+ if timeout:
327
+ md.timeout.CopyFrom(get_proto_timeout(timeout))
328
+
329
+ if cache:
330
+ if cache.behavior == "disable":
331
+ md.discoverable = False
332
+ md.discovery_version = ""
333
+ elif cache.behavior == "override":
334
+ md.discoverable = True
335
+ if not cache.version_override:
336
+ raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
337
+ md.discovery_version = cache.version_override
338
+ else:
339
+ if cache.behavior == "auto":
340
+ raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
341
+ raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
342
+ md.cache_serializable = cache.serialize
343
+ md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
344
+
345
+ return TaskDetails(pb2, max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes)
316
346
 
317
347
  def __rich_repr__(self) -> rich.repr.Result:
318
348
  """
flyte/report/_report.py CHANGED
@@ -4,7 +4,6 @@ import string
4
4
  from dataclasses import dataclass, field
5
5
  from typing import TYPE_CHECKING, Dict, List, Union
6
6
 
7
- from flyte._internal.runtime import io
8
7
  from flyte._logging import logger
9
8
  from flyte._tools import ipython_check
10
9
  from flyte.syncify import syncify
@@ -133,6 +132,7 @@ async def flush():
133
132
  """
134
133
  import flyte.storage as storage
135
134
  from flyte._context import internal_ctx
135
+ from flyte._internal.runtime import io
136
136
 
137
137
  if not internal_ctx().is_task_context():
138
138
  return
@@ -455,35 +455,6 @@ class DataclassTransformer(TypeTransformer[object]):
455
455
  2. Deserialization: The dataclass transformer converts the MessagePack Bytes back to a dataclass.
456
456
  (1) Convert MessagePack Bytes to a dataclass using mashumaro.
457
457
  (2) Handle dataclass attributes to ensure they are of the correct types.
458
-
459
- TODO: Update the example using mashumaro instead of the older library
460
-
461
- Example
462
-
463
- .. code-block:: python
464
-
465
- @dataclass
466
- class Test:
467
- a: int
468
- b: str
469
-
470
- t = Test(a=10,b="e")
471
- JSONSchema().dump(t.schema())
472
-
473
- Output will look like
474
-
475
- .. code-block:: json
476
-
477
- {'$schema': 'http://json-schema.org/draft-07/schema#',
478
- 'definitions': {'TestSchema': {'properties': {'a': {'title': 'a',
479
- 'type': 'number',
480
- 'format': 'integer'},
481
- 'b': {'title': 'b', 'type': 'string'}},
482
- 'type': 'object',
483
- 'additionalProperties': False}},
484
- '$ref': '#/definitions/TestSchema'}
485
-
486
-
487
458
  """
488
459
 
489
460
  def __init__(self) -> None:
@@ -615,7 +586,7 @@ class DataclassTransformer(TypeTransformer[object]):
615
586
  }
616
587
  )
617
588
 
618
- # The type engine used to publish a type structure for attribute access. As of v2, this is no longer needed.
589
+ # The type engine used to publish the type `structure` for attribute access. As of v2, this is no longer needed.
619
590
  return types_pb2.LiteralType(
620
591
  simple=types_pb2.SimpleType.STRUCT,
621
592
  metadata=schema,
@@ -101,7 +101,6 @@ def main(
101
101
  from flyte._logging import logger
102
102
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
103
103
 
104
- logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
105
104
  logger.info("Registering faulthandler for SIGUSR1")
106
105
  faulthandler.register(signal.SIGUSR1)
107
106
 
@@ -117,6 +116,8 @@ def main(
117
116
  if name.startswith("{{"):
118
117
  name = os.getenv("ACTION_NAME", "")
119
118
 
119
+ logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
120
+
120
121
  if debug and name == "a0":
121
122
  from flyte._debug.vscode import _start_vscode_server
122
123
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 2.0.0b18
3
+ Version: 2.0.0b20
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10
@@ -24,6 +24,7 @@ Requires-Dist: toml>=0.10.2
24
24
  Requires-Dist: async-lru>=2.0.5
25
25
  Requires-Dist: mashumaro
26
26
  Requires-Dist: dataclasses_json
27
+ Requires-Dist: aiolimiter>=1.2.1
27
28
  Dynamic: license-file
28
29
 
29
30
  # Flyte 2 SDK 🚀