flyte 0.2.0b33__py3-none-any.whl → 0.2.0b34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -212,11 +212,11 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
     from flyte._utils.lazy_module import is_imported
 
     files = []
-    union_root = os.path.dirname(flyte.__file__)
+    flyte_root = os.path.dirname(flyte.__file__)
 
     # These directories contain installed packages or modules from the Python standard library.
     # If a module is from these directories, then they are not user files.
-    invalid_directories = [union_root, sys.prefix, sys.base_prefix, site.getusersitepackages(), *site.getsitepackages()]
+    invalid_directories = [flyte_root, sys.prefix, sys.base_prefix, site.getusersitepackages(), *site.getsitepackages()]
 
     for mod in modules:
         # Be careful not to import a module with the .__file__ call if not yet imported.
flyte/_image.py CHANGED
@@ -545,8 +545,6 @@ class Image:
         if not script.suffix == ".py":
             raise ValueError(f"UV script {script} must have a .py extension")
         header = parse_uv_script_file(script)
-        if registry is None:
-            raise ValueError("registry must be specified")
 
         # todo: arch
         img = cls.from_debian_base(registry=registry, name=name, python_version=python_version, platform=platform)
@@ -10,6 +10,7 @@ from uuid import uuid4
 import click
 
 import flyte
+import flyte.errors
 from flyte import Image, remote
 from flyte._image import (
     AptPackages,
@@ -53,7 +54,7 @@ class RemoteImageChecker(ImageChecker):
             )
         except Exception as e:
             msg = "remote image builder is not enabled. Please contact Union support to enable it."
-            raise click.ClickException(msg) from e
+            raise flyte.errors.ImageBuildError(msg) from e
 
         image_name = f"{repository.split('/')[-1]}:{tag}"
 
@@ -115,7 +116,7 @@ class RemoteImageBuilder(ImageBuilder):
         if run_details.action_details.raw_phase == run_definition_pb2.PHASE_SUCCEEDED:
             logger.warning(click.style(f"✅ Build completed in {elapsed}!", bold=True, fg="green"))
         else:
-            raise click.ClickException(f"❌ Build failed in {elapsed} at {click.style(run.url, fg='cyan')}")
+            raise flyte.errors.ImageBuildError(f"❌ Build failed in {elapsed} at {click.style(run.url, fg='cyan')}")
 
         outputs = await run_details.outputs()
         return _get_fully_qualified_image_name(outputs)
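
Note: with this release, failures in the remote image builder are raised as flyte.errors.ImageBuildError (a RuntimeUserError subclass added in flyte/errors.py below) instead of click.ClickException, so programmatic callers can handle build failures without depending on click. A minimal sketch; the call that triggers the build is a hypothetical placeholder:

    import flyte.errors

    try:
        image_uri = trigger_remote_image_build()  # hypothetical helper that runs the remote builder
    except flyte.errors.ImageBuildError as err:
        # Build failures are now ordinary Flyte runtime errors rather than CLI exceptions.
        print(f"Image build failed: {err}")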
flyte/_task.py CHANGED
@@ -18,6 +18,7 @@ from typing import (
     TypeAlias,
     TypeVar,
     Union,
+    cast,
 )
 
 from flyte._pod import PodTemplate
@@ -90,7 +91,7 @@ class TaskTemplate(Generic[P, R]):
     cache: CacheRequest = "auto"
     interruptable: bool = False
     retries: Union[int, RetryStrategy] = 0
-    reusable: Union[ReusePolicy, Literal["auto"], None] = None
+    reusable: Union[ReusePolicy, None] = None
     docs: Optional[Documentation] = None
     env: Optional[Dict[str, str]] = None
     secrets: Optional[SecretRequest] = None
@@ -307,13 +308,11 @@ class TaskTemplate(Generic[P, R]):
     def override(
         self,
         *,
-        local: Optional[bool] = None,
-        ref: Optional[bool] = None,
         resources: Optional[Resources] = None,
         cache: CacheRequest = "auto",
         retries: Union[int, RetryStrategy] = 0,
         timeout: Optional[TimeoutType] = None,
-        reusable: Union[ReusePolicy, Literal["auto"], None] = None,
+        reusable: Union[ReusePolicy, Literal["off"], None] = None,
         env: Optional[Dict[str, str]] = None,
         secrets: Optional[SecretRequest] = None,
         **kwargs: Any,
@@ -322,15 +321,37 @@ class TaskTemplate(Generic[P, R]):
         Override various parameters of the task template. This allows for dynamic configuration of the task
         when it is called, such as changing the image, resources, cache policy, etc.
         """
-        resources = resources or self.resources
         cache = cache or self.cache
         retries = retries or self.retries
         timeout = timeout or self.timeout
         reusable = reusable or self.reusable
+        if reusable == "off":
+            reusable = None
+
+        if reusable is not None:
+            if resources is not None:
+                raise ValueError(
+                    "Cannot override resources when reusable is set."
+                    " Reusable tasks will use the parent env's resources. You can disable reusability and"
+                    " override resources if needed. (set reusable='off')"
+                )
+            if env is not None:
+                raise ValueError(
+                    "Cannot override env when reusable is set."
+                    " Reusable tasks will use the parent env's env. You can disable reusability and "
+                    "override env if needed. (set reusable='off')"
+                )
+            if secrets is not None:
+                raise ValueError(
+                    "Cannot override secrets when reusable is set."
+                    " Reusable tasks will use the parent env's secrets. You can disable reusability and "
+                    "override secrets if needed. (set reusable='off')"
+                )
+
+        resources = resources or self.resources
         env = env or self.env
         secrets = secrets or self.secrets
-        local = local or self.local
-        ref = ref or self.ref
+
         for k, v in kwargs.items():
             if k == "name":
                 raise ValueError("Name cannot be overridden")
@@ -340,13 +361,14 @@ class TaskTemplate(Generic[P, R]):
                 raise ValueError("Docs cannot be overridden")
             if k == "interface":
                 raise ValueError("Interface cannot be overridden")
+
         return replace(
             self,
             resources=resources,
             cache=cache,
             retries=retries,
             timeout=timeout,
-            reusable=reusable,
+            reusable=cast(Optional[ReusePolicy], reusable),
             env=env,
             secrets=secrets,
         )
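
The override() change above removes the local/ref parameters, replaces the Literal["auto"] sentinel with Literal["off"], and rejects resource, env, and secret overrides while a task stays reusable. A short usage sketch under those rules, assuming my_task was defined with a ReusePolicy; my_task and my_resources are hypothetical placeholders:

    # Allowed: reuse stays enabled, only execution knobs change.
    t1 = my_task.override(retries=3)

    # Raises ValueError: resources cannot be overridden while the task is reusable.
    # t2 = my_task.override(resources=my_resources)

    # Allowed: disable reuse for this call, then per-call resources apply again.
    t3 = my_task.override(reusable="off", resources=my_resources)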
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.0b33'
-__version_tuple__ = version_tuple = (0, 2, 0, 'b33')
+__version__ = version = '0.2.0b34'
+__version_tuple__ = version_tuple = (0, 2, 0, 'b34')
flyte/cli/_common.py CHANGED
@@ -62,6 +62,28 @@ def _common_options() -> List[click.Option]:
 
 
 # This is global state for the CLI, it is manipulated by the main command
+_client_secret_options = ["client-secret", "client_secret", "clientsecret", "app-credential", "app_credential"]
+_device_flow_options = ["headless", "device-flow", "device_flow"]
+_pkce_options = ["pkce"]
+_external_command_options = ["external-command", "external_command", "externalcommand", "command", "custom"]
+ALL_AUTH_OPTIONS = _client_secret_options + _device_flow_options + _pkce_options + _external_command_options
+
+
+def sanitize_auth_type(auth_type: str | None) -> str:
+    """
+    Convert the auth type to the mode that is used by the Flyte backend.
+    """
+    if auth_type is None:
+        return "pkce"
+    if auth_type.lower() in _pkce_options:
+        return "Pkce"
+    if auth_type.lower() in _device_flow_options:
+        return "DeviceFlow"
+    if auth_type.lower() in _client_secret_options:
+        return "ClientSecret"
+    if auth_type.lower() in _external_command_options:
+        return "ExternalCommand"
+    raise ValueError(f"Unknown auth type: {auth_type}. Supported types are: {ALL_AUTH_OPTIONS}.")
 
 
 @rich.repr.auto
@@ -78,6 +100,7 @@ class CLIConfig:
     insecure: bool = False
     org: str | None = None
     simple: bool = False
+    auth_type: str | None = None
 
     def replace(self, **kwargs) -> CLIConfig:
         """
@@ -99,6 +122,8 @@ class CLIConfig:
             kwargs["endpoint"] = self.endpoint
         if self.insecure is not None:
             kwargs["insecure"] = self.insecure
+        if self.auth_type:
+            kwargs["auth_mode"] = sanitize_auth_type(self.auth_type)
         platform_cfg = self.config.platform.replace(**kwargs)
 
         updated_config = self.config.with_params(platform_cfg, task_cfg)
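
The sanitize_auth_type helper added above maps the user-facing spellings accepted by --auth-type onto the auth-mode strings stored in the platform config. A quick sketch of the mapping, assuming the helper is imported from flyte.cli._common as defined in this file:

    from flyte.cli._common import sanitize_auth_type

    assert sanitize_auth_type(None) == "pkce"                # no flag given: default mode
    assert sanitize_auth_type("PKCE") == "Pkce"              # matching is case-insensitive
    assert sanitize_auth_type("device-flow") == "DeviceFlow"
    assert sanitize_auth_type("app-credential") == "ClientSecret"
    assert sanitize_auth_type("custom") == "ExternalCommand"
    sanitize_auth_type("kerberos")                           # raises ValueError: unknown auth type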
flyte/cli/_create.py CHANGED
@@ -95,6 +95,14 @@ def secret(
     help="Image builder to use for building images. Defaults to 'local'.",
     show_default=True,
 )
+@click.option(
+    "--auth-type",
+    type=click.Choice(common.ALL_AUTH_OPTIONS, case_sensitive=False),
+    default=None,
+    help="Authentication type to use for the Flyte backend. Defaults to 'pkce'.",
+    show_default=True,
+    required=False,
+)
 def config(
     output: str,
     endpoint: str | None = None,
@@ -104,6 +112,7 @@ def config(
     domain: str | None = None,
     force: bool = False,
     image_builder: str | None = None,
+    auth_type: str | None = None,
 ):
     """
     Creates a configuration file for Flyte CLI.
@@ -128,6 +137,8 @@ def config(
         admin["endpoint"] = endpoint
     if insecure:
         admin["insecure"] = insecure
+    if auth_type:
+        admin["authType"] = common.sanitize_auth_type(auth_type)
 
     if not org and endpoint:
         org = org_from_endpoint(endpoint)
flyte/cli/main.py CHANGED
@@ -2,6 +2,7 @@ import rich_click as click
 
 from flyte._logging import initialize_logger, logger
 
+from . import _common as common
 from ._abort import abort
 from ._build import build
 from ._common import CLIConfig
@@ -74,6 +75,14 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
     default=None,
     show_default=True,
 )
+@click.option(
+    "--auth-type",
+    type=click.Choice(common.ALL_AUTH_OPTIONS, case_sensitive=False),
+    default=None,
+    help="Authentication type to use for the Flyte backend. Defaults to 'pkce'.",
+    show_default=True,
+    required=False,
+)
 @click.option(
     "-v",
     "--verbose",
@@ -113,6 +122,7 @@ def main(
     org: str | None,
     config_file: str | None,
     simple: bool = False,
+    auth_type: str | None = None,
 ):
     """
     The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
@@ -167,6 +177,7 @@ def main(
         config=cfg,
         ctx=ctx,
         simple=simple,
+        auth_type=auth_type,
     )
 
 
flyte/errors.py CHANGED
@@ -170,3 +170,12 @@ class DeploymentError(RuntimeUserError):
 
     def __init__(self, message: str):
         super().__init__("DeploymentError", message, "user")
+
+
+class ImageBuildError(RuntimeUserError):
+    """
+    This error is raised when the image build fails.
+    """
+
+    def __init__(self, message: str):
+        super().__init__("ImageBuildError", message, "user")
flyte/io/__init__.py CHANGED
@@ -7,21 +7,21 @@ of large datasets in Union.
 """
 
 __all__ = [
+    "DataFrame",
+    "DataFrameDecoder",
+    "DataFrameEncoder",
+    "DataFrameTransformerEngine",
     "Dir",
     "File",
-    "StructuredDataset",
-    "StructuredDatasetDecoder",
-    "StructuredDatasetEncoder",
-    "StructuredDatasetTransformerEngine",
-    "lazy_import_structured_dataset_handler",
+    "lazy_import_dataframe_handler",
 ]
 
+from ._dataframe import (
+    DataFrame,
+    DataFrameDecoder,
+    DataFrameEncoder,
+    DataFrameTransformerEngine,
+    lazy_import_dataframe_handler,
+)
 from ._dir import Dir
 from ._file import File
-from ._structured_dataset import (
-    StructuredDataset,
-    StructuredDatasetDecoder,
-    StructuredDatasetEncoder,
-    StructuredDatasetTransformerEngine,
-    lazy_import_structured_dataset_handler,
-)
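
For user code, the flyte.io surface renames the StructuredDataset family to DataFrame; a minimal before/after import sketch:

    # flyte 0.2.0b33
    # from flyte.io import StructuredDataset, StructuredDatasetTransformerEngine

    # flyte 0.2.0b34
    from flyte.io import DataFrame, DataFrameTransformerEngine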
@@ -1,15 +1,15 @@
 """
-Flytekit StructuredDataset
+Flytekit DataFrame
 ==========================================================
-.. currentmodule:: flytekit.types.structured
+.. currentmodule:: flyte.io._dataframe
 
 .. autosummary::
    :template: custom.rst
    :toctree: generated/
 
-   StructuredDataset
-   StructuredDatasetDecoder
-   StructuredDatasetEncoder
+   DataFrame
+   DataFrameDecoder
+   DataFrameEncoder
 """
 
 import functools
@@ -17,12 +17,12 @@ import functools
 from flyte._logging import logger
 from flyte._utils.lazy_module import is_imported
 
-from .structured_dataset import (
+from .dataframe import (
+    DataFrame,
+    DataFrameDecoder,
+    DataFrameEncoder,
+    DataFrameTransformerEngine,
     DuplicateHandlerError,
-    StructuredDataset,
-    StructuredDatasetDecoder,
-    StructuredDatasetEncoder,
-    StructuredDatasetTransformerEngine,
 )
 
 
@@ -30,8 +30,8 @@ from .structured_dataset import (
 def register_csv_handlers():
     from .basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler
 
-    StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler(), default_format_for_type=True)
-    StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register(PandasToCSVEncodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register(CSVToPandasDecodingHandler(), default_format_for_type=True)
 
 
 @functools.lru_cache(maxsize=None)
@@ -42,9 +42,9 @@ def register_pandas_handlers():
 
     from .basic_dfs import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler
 
-    StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
-    StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True)
-    StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
+    DataFrameTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
 
 
 @functools.lru_cache(maxsize=None)
@@ -55,9 +55,9 @@ def register_arrow_handlers():
 
     from .basic_dfs import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler
 
-    StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True)
-    StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True)
-    StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
+    DataFrameTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True)
+    DataFrameTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
 
 
 @functools.lru_cache(maxsize=None)
@@ -70,10 +70,10 @@ def register_bigquery_handlers():
             PandasToBQEncodingHandlers,
         )
 
-        StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers())
-        StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler())
-        StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers())
-        StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler())
+        DataFrameTransformerEngine.register(PandasToBQEncodingHandlers())
+        DataFrameTransformerEngine.register(BQToPandasDecodingHandler())
+        DataFrameTransformerEngine.register(ArrowToBQEncodingHandlers())
+        DataFrameTransformerEngine.register(BQToArrowDecodingHandler())
     except ImportError:
         logger.info(
             "We won't register bigquery handler for structured dataset because "
@@ -86,8 +86,8 @@ def register_snowflake_handlers():
     try:
         from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler
 
-        StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler())
-        StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers())
+        DataFrameTransformerEngine.register(SnowflakeToPandasDecodingHandler())
+        DataFrameTransformerEngine.register(PandasToSnowflakeEncodingHandlers())
 
     except ImportError:
         logger.info(
@@ -96,7 +96,7 @@ def register_snowflake_handlers():
         )
 
 
-def lazy_import_structured_dataset_handler():
+def lazy_import_dataframe_handler():
     if is_imported("pandas"):
         try:
             register_pandas_handlers()
@@ -121,9 +121,9 @@ def lazy_import_structured_dataset_handler():
 
 
 __all__ = [
-    "StructuredDataset",
-    "StructuredDatasetDecoder",
-    "StructuredDatasetEncoder",
-    "StructuredDatasetTransformerEngine",
-    "lazy_import_structured_dataset_handler",
+    "DataFrame",
+    "DataFrameDecoder",
+    "DataFrameEncoder",
+    "DataFrameTransformerEngine",
+    "lazy_import_dataframe_handler",
 ]
@@ -9,12 +9,12 @@ from fsspec.core import split_protocol, strip_protocol
 import flyte.storage as storage
 from flyte._logging import logger
 from flyte._utils import lazy_module
-from flyte.io._structured_dataset.structured_dataset import (
+from flyte.io._dataframe.dataframe import (
     CSV,
     PARQUET,
-    StructuredDataset,
-    StructuredDatasetDecoder,
-    StructuredDatasetEncoder,
+    DataFrame,
+    DataFrameDecoder,
+    DataFrameEncoder,
 )
 
 if typing.TYPE_CHECKING:
@@ -39,27 +39,27 @@ def get_pandas_storage_options(uri: str, anonymous: bool = False) -> typing.Opti
     return None
 
 
-class PandasToCSVEncodingHandler(StructuredDatasetEncoder):
+class PandasToCSVEncodingHandler(DataFrameEncoder):
     def __init__(self):
         super().__init__(pd.DataFrame, None, CSV)
 
     async def encode(
         self,
-        structured_dataset: StructuredDataset,
+        dataframe: DataFrame,
         structured_dataset_type: types_pb2.StructuredDatasetType,
     ) -> literals_pb2.StructuredDataset:
-        if not structured_dataset.uri:
+        if not dataframe.uri:
            from flyte._context import internal_ctx
 
            ctx = internal_ctx()
            uri = ctx.raw_data.get_random_remote_path()
        else:
-            uri = typing.cast(str, structured_dataset.uri)
+            uri = typing.cast(str, dataframe.uri)
 
        if not storage.is_remote(uri):
            Path(uri).mkdir(parents=True, exist_ok=True)
        path = os.path.join(uri, ".csv")
-        df = typing.cast(pd.DataFrame, structured_dataset.dataframe)
+        df = typing.cast(pd.DataFrame, dataframe.val)
        df.to_csv(
            path,
            index=False,
@@ -71,7 +71,7 @@ class PandasToCSVEncodingHandler(StructuredDatasetEncoder):
         )
 
 
-class CSVToPandasDecodingHandler(StructuredDatasetDecoder):
+class CSVToPandasDecodingHandler(DataFrameDecoder):
     def __init__(self):
         super().__init__(pd.DataFrame, None, CSV)
 
@@ -97,27 +97,27 @@ class CSVToPandasDecodingHandler(StructuredDatasetDecoder):
             raise
 
 
-class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
+class PandasToParquetEncodingHandler(DataFrameEncoder):
     def __init__(self):
         super().__init__(pd.DataFrame, None, PARQUET)
 
     async def encode(
         self,
-        structured_dataset: StructuredDataset,
+        dataframe: DataFrame,
        structured_dataset_type: types_pb2.StructuredDatasetType,
     ) -> literals_pb2.StructuredDataset:
-        if not structured_dataset.uri:
+        if not dataframe.uri:
            from flyte._context import internal_ctx
 
            ctx = internal_ctx()
            uri = str(ctx.raw_data.get_random_remote_path())
        else:
-            uri = typing.cast(str, structured_dataset.uri)
+            uri = typing.cast(str, dataframe.uri)
 
        if not storage.is_remote(uri):
            Path(uri).mkdir(parents=True, exist_ok=True)
        path = os.path.join(uri, f"{0:05}")
-        df = typing.cast(pd.DataFrame, structured_dataset.dataframe)
+        df = typing.cast(pd.DataFrame, dataframe.val)
        df.to_parquet(
            path,
            coerce_timestamps="us",
@@ -130,7 +130,7 @@ class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
         )
 
 
-class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
+class ParquetToPandasDecodingHandler(DataFrameDecoder):
     def __init__(self):
         super().__init__(pd.DataFrame, None, PARQUET)
 
@@ -155,36 +155,34 @@ class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
             raise
 
 
-class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
+class ArrowToParquetEncodingHandler(DataFrameEncoder):
     def __init__(self):
         super().__init__(pa.Table, None, PARQUET)
 
     async def encode(
         self,
-        structured_dataset: StructuredDataset,
-        structured_dataset_type: types_pb2.StructuredDatasetType,
+        dataframe: DataFrame,
+        dataframe_type: types_pb2.StructuredDatasetType,
     ) -> literals_pb2.StructuredDataset:
         import pyarrow.parquet as pq
 
-        if not structured_dataset.uri:
+        if not dataframe.uri:
            from flyte._context import internal_ctx
 
            ctx = internal_ctx()
            uri = ctx.raw_data.get_random_remote_path()
        else:
-            uri = typing.cast(str, structured_dataset.uri)
+            uri = typing.cast(str, dataframe.uri)
 
        if not storage.is_remote(uri):
            Path(uri).mkdir(parents=True, exist_ok=True)
        path = os.path.join(uri, f"{0:05}")
        filesystem = storage.get_underlying_filesystem(path=path)
-        pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem)
-        return literals_pb2.StructuredDataset(
-            uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type)
-        )
+        pq.write_table(dataframe.val, strip_protocol(path), filesystem=filesystem)
+        return literals_pb2.StructuredDataset(uri=uri, metadata=literals_pb2.StructuredDatasetMetadata(dataframe_type))
 
 
-class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
+class ParquetToArrowDecodingHandler(DataFrameDecoder):
     def __init__(self):
         super().__init__(pa.Table, None, PARQUET)
 
@@ -211,5 +209,6 @@ class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
             fs = storage.get_underlying_filesystem(path=uri, anonymous=True)
             if fs is not None:
                 return pq.read_table(path, filesystem=fs, columns=columns)
+            return None
         else:
             raise
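
Custom handlers follow the same renames as the built-in pandas and arrow handlers above: subclass DataFrameEncoder instead of StructuredDatasetEncoder, accept a DataFrame, and read the wrapped value from .val rather than .dataframe. A rough sketch under those assumptions; MyFrame, the "parquet" format string, and the unimplemented body are hypothetical placeholders:

    from flyte.io import DataFrame, DataFrameEncoder, DataFrameTransformerEngine


    class MyFrame:
        """Hypothetical dataframe-like type, used only for illustration."""


    class MyFrameToParquetEncoder(DataFrameEncoder):
        def __init__(self):
            # Same (python_type, protocol, format) constructor shape as the handlers above.
            super().__init__(MyFrame, None, "parquet")

        async def encode(self, dataframe: DataFrame, structured_dataset_type):
            frame = dataframe.val  # previously accessed as structured_dataset.dataframe
            uri = dataframe.uri    # optional target location, as in the handlers above
            raise NotImplementedError("serialize `frame` to `uri` and return a literals_pb2.StructuredDataset")


    # Registration keeps the same shape; only the engine class is renamed.
    DataFrameTransformerEngine.register(MyFrameToParquetEncoder(), default_format_for_type=True)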