flyte 0.2.0b14__py3-none-any.whl → 0.2.0b15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import base64
4
5
  import hashlib
5
6
  import inspect
@@ -7,12 +8,11 @@ from dataclasses import dataclass
7
8
  from types import NoneType
8
9
  from typing import Any, Dict, List, Tuple, Union, get_args
9
10
 
10
- from flyteidl.core import execution_pb2, literals_pb2
11
- from flyteidl.core.interface_pb2 import TypedInterface
11
+ from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
12
12
 
13
13
  import flyte.errors
14
14
  import flyte.storage as storage
15
- from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
15
+ from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
16
16
  from flyte.models import ActionID, NativeInterface, TaskContext
17
17
  from flyte.types import TypeEngine, TypeTransformerFailedError
18
18
 
@@ -60,6 +60,40 @@ async def convert_inputs_to_native(inputs: Inputs, python_interface: NativeInter
60
60
  return native_vals
61
61
 
62
62
 
63
async def convert_upload_default_inputs(interface: NativeInterface) -> List[common_pb2.NamedParameter]:
    """
    Convert the default input values of a NativeInterface into NamedParameters for upload.

    Only inputs whose default is not ``inspect.Parameter.empty`` are converted; inputs
    without a default are omitted from the result.

    :param interface: The native interface whose ``inputs`` mapping is read. Each entry maps
        an input name to an ``(input_type, default_value)`` pair.
    :return: One NamedParameter per input that declares a default, in the iteration order of
        ``interface.inputs``. An empty list if there are no inputs or none has a default.
    """
    if not interface.inputs:
        return []

    with_defaults = [
        (input_name, input_type, default_value)
        for input_name, (input_type, default_value) in interface.inputs.items()
        if default_value is not inspect.Parameter.empty
    ]
    if not with_defaults:
        return []

    # Resolve every literal type *before* creating any coroutine. If one of these lookups
    # raises, no coroutine has been created yet, so nothing is left un-awaited.
    literal_types = [TypeEngine.to_literal_type(input_type) for _, input_type, _ in with_defaults]

    # Convert all default values concurrently; gather preserves the input order.
    literals: List[literals_pb2.Literal] = await asyncio.gather(
        *(
            TypeEngine.to_literal(default_value, input_type, lt)
            for (_, input_type, default_value), lt in zip(with_defaults, literal_types)
        )
    )

    return [
        common_pb2.NamedParameter(
            name=input_name,
            parameter=interface_pb2.Parameter(
                var=interface_pb2.Variable(type=lt),
                default=literal,
            ),
        )
        for (input_name, _, _), lt, literal in zip(with_defaults, literal_types, literals)
    ]
95
+
96
+
63
97
  def is_optional_type(tp) -> bool:
64
98
  """
65
99
  True if the *annotation* `tp` is equivalent to Optional[…].
@@ -70,22 +104,42 @@ def is_optional_type(tp) -> bool:
70
104
 
71
105
  async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
72
106
  kwargs = interface.convert_to_kwargs(*args, **kwargs)
73
- if len(kwargs) == 0:
107
+
108
+ if len(kwargs) < interface.num_required_inputs():
109
+ raise ValueError(
110
+ f"Received {len(kwargs)} inputs but interface has {interface.num_required_inputs()} required inputs. "
111
+ f"Please provide all required inputs. Inputs received: {kwargs}, interface: {interface}"
112
+ )
113
+
114
+ if len(interface.inputs) == 0:
74
115
  return Inputs.empty()
75
116
 
76
117
  # fill in defaults if missing
118
+ type_hints: Dict[str, type] = {}
119
+ already_converted_kwargs: Dict[str, literals_pb2.Literal] = {}
77
120
  for input_name, (input_type, default_value) in interface.inputs.items():
78
- if input_name not in kwargs:
79
- if (default_value is not None and default_value is not inspect.Signature.empty) or (
80
- default_value is None and is_optional_type(input_type)
81
- ):
121
+ if input_name in kwargs:
122
+ type_hints[input_name] = input_type
123
+ elif (default_value is not None and default_value is not inspect.Signature.empty) or (
124
+ default_value is None and is_optional_type(input_type)
125
+ ):
126
+ if default_value == NativeInterface.has_default:
127
+ if interface._remote_defaults is None or input_name not in interface._remote_defaults:
128
+ raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
129
+ already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
130
+ else:
82
131
  kwargs[input_name] = default_value
83
- if len(kwargs) < len(interface.inputs):
84
- raise ValueError(
85
- f"Received {len(kwargs)} inputs but interface has {len(interface.inputs)}. "
86
- f"Please provide all required inputs."
87
- )
88
- literal_map = await TypeEngine.dict_to_literal_map(kwargs, interface.get_input_types())
132
+ type_hints[input_name] = input_type
133
+
134
+ literal_map = await TypeEngine.dict_to_literal_map(kwargs, type_hints)
135
+ if len(already_converted_kwargs) > 0:
136
+ copied_literals: Dict[str, literals_pb2.Literal] = {}
137
+ for k, v in literal_map.literals.items():
138
+ copied_literals[k] = v
139
+ # Add the already converted kwargs to the literal map
140
+ for k, v in already_converted_kwargs.items():
141
+ copied_literals[k] = v
142
+ literal_map = literals_pb2.LiteralMap(literals=copied_literals)
89
143
  # Make sure we use the interface, not literal_map or kwargs, because those may have a different order
90
144
  return Inputs(
91
145
  proto_inputs=run_definition_pb2.Inputs(
@@ -228,7 +282,7 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
228
282
  def generate_cache_key_hash(
229
283
  task_name: str,
230
284
  inputs_hash: str,
231
- task_interface: TypedInterface,
285
+ task_interface: interface_pb2.TypedInterface,
232
286
  cache_version: str,
233
287
  ignored_input_vars: List[str],
234
288
  proto_inputs: run_definition_pb2.Inputs,
@@ -16,7 +16,7 @@ import flyte.errors
16
16
  from flyte._cache.cache import VersionParameters, cache_from_request
17
17
  from flyte._logging import logger
18
18
  from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
19
- from flyte._protos.workflow import task_definition_pb2
19
+ from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
20
20
  from flyte._secret import SecretRequest, secrets_from_request
21
21
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
22
22
  from flyte.models import CodeBundle, SerializationContext
@@ -26,6 +26,9 @@ from ..._timeout import TimeoutType, timeout_from_request
26
26
  from .resources_serde import get_proto_extended_resources, get_proto_resources
27
27
  from .types_serde import transform_native_to_typed_interface
28
28
 
29
+ _MAX_ENV_NAME_LENGTH = 63 # Maximum length for environment names
30
+ _MAX_TASK_SHORT_NAME_LENGTH = 63 # Maximum length for task short names
31
+
29
32
 
30
33
  def load_class(qualified_name) -> Type:
31
34
  """
@@ -52,17 +55,31 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
52
55
 
53
56
 
54
57
def translate_task_to_wire(
    task: TaskTemplate,
    serialization_context: SerializationContext,
    default_inputs: Optional[typing.List[common_pb2.NamedParameter]] = None,
) -> task_definition_pb2.TaskSpec:
    """
    Translate a task to its wire format (a TaskSpec).

    The TaskSpec wraps the task template produced by ``get_proto_task`` together with the
    optional default inputs, a short name and, when the task still has a live parent
    environment, that environment's name.

    :param task: The task to translate.
    :param serialization_context: The serialization context to use for the translation.
    :param default_inputs: Optional list of default inputs for the task.

    :return: The translated task.
    """
    task_template = get_proto_task(task, serialization_context)

    # parent_env is a weak reference: dereference it exactly once and keep the strong
    # reference, so the environment cannot disappear between the check and the use.
    parent_env = task.parent_env() if task.parent_env else None
    env: environment_pb2.Environment | None = None
    if parent_env:
        # Names are truncated to the maximum length allowed for environment names.
        env = environment_pb2.Environment(name=parent_env.name[:_MAX_ENV_NAME_LENGTH])

    return task_definition_pb2.TaskSpec(
        task_template=task_template,
        default_inputs=default_inputs,
        short_name=task.friendly_name[:_MAX_TASK_SHORT_NAME_LENGTH],
        environment=env,
    )
66
83
 
67
84
 
68
85
  def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_pb2.SecurityContext]:
@@ -111,7 +128,7 @@ def get_proto_timeout(timeout: TimeoutType | None) -> Optional[duration_pb2.Dura
111
128
  return duration_pb2.Duration(seconds=max_runtime_timeout.seconds)
112
129
 
113
130
 
114
- def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> task_definition_pb2.TaskSpec:
131
+ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> tasks_pb2.TaskTemplate:
115
132
  task_id = identifier_pb2.Identifier(
116
133
  resource_type=identifier_pb2.ResourceType.TASK,
117
134
  project=serialize_context.project,
@@ -158,7 +175,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
158
175
  else:
159
176
  logger.debug(f"Cache disabled for task {task.name}")
160
177
 
161
- tt = tasks_pb2.TaskTemplate(
178
+ return tasks_pb2.TaskTemplate(
162
179
  id=task_id,
163
180
  type=task.task_type,
164
181
  metadata=tasks_pb2.TaskMetadata(
@@ -166,12 +183,16 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
166
183
  discovery_version=cache_version,
167
184
  cache_serializable=task_cache.serialize,
168
185
  cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
169
- runtime=tasks_pb2.RuntimeMetadata(),
186
+ runtime=tasks_pb2.RuntimeMetadata(
187
+ version=flyte.version(),
188
+ type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
189
+ flavor="python",
190
+ ),
170
191
  retries=get_proto_retry_strategy(task.retries),
171
192
  timeout=get_proto_timeout(task.timeout),
172
193
  pod_template_name=task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None,
173
194
  interruptible=task.interruptable,
174
- generates_deck=wrappers_pb2.BoolValue(value=False), # TODO add support for reports
195
+ generates_deck=wrappers_pb2.BoolValue(value=task.report),
175
196
  ),
176
197
  interface=transform_native_to_typed_interface(task.native_interface),
177
198
  custom=custom,
@@ -183,7 +204,6 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
183
204
  sql=sql,
184
205
  extended_resources=get_proto_extended_resources(task.resources),
185
206
  )
186
- return task_definition_pb2.TaskSpec(task_template=tt)
187
207
 
188
208
 
189
209
  def _get_urun_container(
flyte/_logging.py CHANGED
@@ -9,6 +9,17 @@ from ._tools import ipython_check, is_in_cluster
9
9
  DEFAULT_LOG_LEVEL = logging.WARNING
10
10
 
11
11
 
12
def make_hyperlink(label: str, url: str):
    """
    Build a clickable terminal hyperlink for ``label`` that points at ``url``.

    The label is coloured with the ANSI bright-blue code and wrapped in an OSC 8
    hyperlink escape sequence. The colour is reset before the hyperlink is closed.

    :param label: Text to display.
    :param url: Target of the hyperlink.
    :return: The label surrounded by the colour and hyperlink escape sequences.
    """
    colour_on, colour_off = "\033[94m", "\033[0m"
    link_open = f"\033]8;;{url}\033\\"
    link_close = "\033]8;;\033\\"
    return f"{colour_on}{link_open}{label}{colour_off}{link_close}"
21
+
22
+
12
23
  def is_rich_logging_disabled() -> bool:
13
24
  """
14
25
  Check if rich logging is enabled
flyte/_run.py CHANGED
@@ -70,7 +70,7 @@ class _Runner:
70
70
  if not force_mode and client is not None:
71
71
  force_mode = "remote"
72
72
  force_mode = force_mode or "local"
73
- logger.debug(f"Effective run mode: {force_mode}, client configured: {client is not None}")
73
+ logger.debug(f"Effective run mode: `{force_mode}`, client configured: `{client is not None}`")
74
74
  self._mode = force_mode
75
75
  self._name = name
76
76
  self._service_account = service_account
flyte/_task.py CHANGED
@@ -83,6 +83,7 @@ class TaskTemplate(Generic[P, R]):
83
83
 
84
84
  name: str
85
85
  interface: NativeInterface
86
+ friendly_name: str = ""
86
87
  task_type: str = "python"
87
88
  task_type_version: int = 0
88
89
  image: Union[str, Image, Literal["auto"]] = "auto"
@@ -108,9 +109,9 @@ class TaskTemplate(Generic[P, R]):
108
109
  def __post_init__(self):
109
110
  # Auto set the image based on the image request
110
111
  if self.image == "auto":
111
- self.image = Image.auto()
112
+ self.image = Image.from_debian_base()
112
113
  elif isinstance(self.image, str):
113
- self.image = Image.from_prebuilt(str(self.image))
114
+ self.image = Image.from_base(str(self.image))
114
115
 
115
116
  # Auto set cache based on the cache request
116
117
  if isinstance(self.cache, str):
@@ -126,6 +127,10 @@ class TaskTemplate(Generic[P, R]):
126
127
  if isinstance(self.retries, int):
127
128
  self.retries = RetryStrategy(count=self.retries)
128
129
 
130
+ if self.friendly_name == "":
131
+ # If friendly_name is not set, use the name of the task
132
+ self.friendly_name = self.name
133
+
129
134
  def __getstate__(self):
130
135
  """
131
136
  This method is called when the object is pickled. We need to remove the parent_env reference
@@ -54,7 +54,7 @@ class TaskEnvironment(Environment):
54
54
  :param resources: Resources to allocate for the environment.
55
55
  :param env: Environment variables to set for the environment.
56
56
  :param secrets: Secrets to inject into the environment.
57
- :param env_dep_hints: Environment dependencies to hint, so when you deploy the environment, the dependencies are
57
+ :param depends_on: Environment dependencies to hint, so when you deploy the environment, the dependencies are
58
58
  also deployed. This is useful when you have a set of environments that depend on each other.
59
59
  :param cache: Cache policy for the environment.
60
60
  :param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
@@ -74,7 +74,7 @@ class TaskEnvironment(Environment):
74
74
  resources: Optional[Resources] = None,
75
75
  env: Optional[Dict[str, str]] = None,
76
76
  secrets: Optional[SecretRequest] = None,
77
- env_dep_hints: Optional[List[Environment]] = None,
77
+ depends_on: Optional[List[Environment]] = None,
78
78
  **kwargs: Any,
79
79
  ) -> TaskEnvironment:
80
80
  """
@@ -103,8 +103,8 @@ class TaskEnvironment(Environment):
103
103
  kwargs["reusable"] = reusable
104
104
  if secrets is not None:
105
105
  kwargs["secrets"] = secrets
106
- if env_dep_hints is not None:
107
- kwargs["env_dep_hints"] = env_dep_hints
106
+ if depends_on is not None:
107
+ kwargs["depends_on"] = depends_on
108
108
  return replace(self, **kwargs)
109
109
 
110
110
  def task(
@@ -122,7 +122,7 @@ class TaskEnvironment(Environment):
122
122
  ) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
123
123
  """
124
124
  :param _func: Optional The function to decorate. If not provided, the decorator will return a callable that
125
- :param name: Optional The name of the task (defaults to the function name)
125
+ :param name: Optional A friendly name for the task (defaults to the function name)
126
126
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
127
127
  task.
128
128
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
@@ -138,8 +138,8 @@ class TaskEnvironment(Environment):
138
138
  raise ValueError("Cannot set pod_template when environment is reusable.")
139
139
 
140
140
  def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
141
- task_name = name or func.__name__
142
- task_name = self.name + "." + task_name
141
+ friendly_name = name or func.__name__
142
+ task_name = self.name + "." + func.__name__
143
143
 
144
144
  tmpl: AsyncFunctionTaskTemplate = AsyncFunctionTaskTemplate(
145
145
  func=func,
@@ -157,6 +157,7 @@ class TaskEnvironment(Environment):
157
157
  parent_env=weakref.ref(self),
158
158
  interface=NativeInterface.from_callable(func),
159
159
  report=report,
160
+ friendly_name=friendly_name,
160
161
  )
161
162
  self._tasks[task_name] = tmpl
162
163
  return tmpl
@@ -13,6 +13,9 @@ async def run_coros(*coros: typing.Coroutine, return_when: str = asyncio.FIRST_C
13
13
  """
14
14
  tasks: typing.List[asyncio.Task[typing.Never]] = [asyncio.create_task(c) for c in coros]
15
15
  done, pending = await asyncio.wait(tasks, return_when=return_when)
16
+ # TODO we might want to handle asyncio.CancelledError here, for cases when the `action` is cancelled
17
+ # and we want to propagate it to all tasks. Though the backend will handle it anyway,
18
+ # so this is not strictly necessary.
16
19
 
17
20
  for t in pending: # type: asyncio.Task
18
21
  t.cancel() # Cancel all tasks that didn't finish first
flyte/_utils/helpers.py CHANGED
@@ -52,7 +52,33 @@ def base36_encode(byte_data: bytes) -> str:
52
52
  return "".join(reversed(base36))
53
53
 
54
54
 
55
- # does not work at all in the setuptools case. see old flytekit editable installs
55
def _iter_editable():
    """
    Yield (project_name, source_path) for every editable distribution
    visible to the current interpreter.

    Two mechanisms are checked for each distribution:

    * ``direct_url.json`` (PEP 610 / PEP 660): if it marks the install as editable, the
      ``file://`` URL is converted to a local path. The conversion percent-decodes the URL
      and uses the platform's URL-to-path rules instead of slicing off a fixed prefix.
    * Legacy ``.egg-link`` files listed among the distribution's files: the first line of
      the file is taken as the source path.

    :return: A generator of ``(name, pathlib.Path)`` tuples.
    """
    import json
    import pathlib
    from importlib.metadata import distributions
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    for dist in distributions():
        # PEP-610 / PEP-660 (preferred, wheel-style editables)
        direct = dist.read_text("direct_url.json")
        if direct:
            data = json.loads(direct)
            if data.get("dir_info", {}).get("editable"):  # spec key
                parsed = urlparse(data["url"])
                if parsed.scheme == "file":
                    # Keep a non-local authority (e.g. a UNC host) as part of the path;
                    # an empty or "localhost" authority refers to the local machine.
                    if parsed.netloc in ("", "localhost"):
                        raw_path = parsed.path
                    else:
                        raw_path = f"//{parsed.netloc}{parsed.path}"
                    yield dist.metadata["Name"], pathlib.Path(url2pathname(raw_path))
                # An editable direct URL was found; do not also scan for egg-links.
                continue

        # Legacy setuptools-develop / pip-e (egg-link)
        for file in dist.files or ():  # importlib.metadata 3.8+
            if file.suffix == ".egg-link":
                with open(dist.locate_file(file), "r") as f:
                    source_dir = f.readline().strip()
                # An empty first line carries no path; skip it rather than yield ".".
                if source_dir:
                    yield dist.metadata["Name"], pathlib.Path(source_dir)
80
+
81
+
56
82
  def get_cwd_editable_install() -> typing.Optional[Path]:
57
83
  """
58
84
  This helper function is incomplete since it hasn't been tested with all the package managers out there,
@@ -65,28 +91,13 @@ def get_cwd_editable_install() -> typing.Optional[Path]:
65
91
 
66
92
  :return:
67
93
  """
68
- import site
69
94
 
70
95
  from flyte._logging import logger
71
96
 
72
- egg_links = [Path(p) for p in Path(site.getsitepackages()[0]).glob("*.egg-link")]
73
- pth_files = [Path(p) for p in Path(site.getsitepackages()[0]).glob("*.pth")]
74
-
75
- if not egg_links and not pth_files:
76
- logger.debug("No editable installs found.")
77
- return None
78
-
79
97
  editable_installs = []
80
- egg_links.extend(pth_files)
81
- for file in egg_links:
82
- with open(file, "r") as f:
83
- line = f.readline()
84
- if line:
85
- # Check if the first line is a directory
86
- p = Path(line)
87
- if p.is_dir():
88
- editable_installs.append(p)
89
- logger.debug(f"Editable installs: {editable_installs}")
98
+ for name, path in _iter_editable():
99
+ logger.debug(f"Detected editable install: {name} at {path}")
100
+ editable_installs.append(path)
90
101
 
91
102
  # check to see if the current working directory is in any of the editable installs
92
103
  # including if the current folder is the root folder, one level up from the src and contains
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b14'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b14')
20
+ __version__ = version = '0.2.0b15'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b15')
flyte/cli/_common.py CHANGED
@@ -19,7 +19,6 @@ from rich.table import Table
19
19
  from rich.traceback import Traceback
20
20
 
21
21
  import flyte.errors
22
- from flyte._logging import logger
23
22
  from flyte.config import Config
24
23
 
25
24
  PREFERRED_BORDER_COLOR = "dim cyan"
@@ -104,8 +103,7 @@ class CLIConfig:
104
103
 
105
104
  updated_config = self.config.with_params(platform_cfg, task_cfg)
106
105
 
107
- logger.debug(f"Initializing CLI with config: {updated_config}")
108
- flyte.init_from_config(updated_config)
106
+ flyte.init_from_config(updated_config, log_level=self.log_level)
109
107
 
110
108
 
111
109
  class InvokeBaseMixin:
@@ -316,7 +314,7 @@ def get_table(title: str, vals: Iterable[Any]) -> Table:
316
314
  has_rich_repr = True
317
315
  elif not isinstance(p, (list, tuple)):
318
316
  raise ValueError("Expected a list or tuple of values, or an object with __rich_repr__ method.")
319
- o = p.__rich_repr__() if has_rich_repr else p
317
+ o = list(p.__rich_repr__()) if has_rich_repr else p
320
318
  if headers is None:
321
319
  headers = [k for k, _ in o]
322
320
  for h in headers:
flyte/cli/_deploy.py CHANGED
@@ -111,7 +111,7 @@ class EnvPerFileGroup(common.ObjectsPerFileGroup):
111
111
  name=obj_name,
112
112
  obj_name=obj_name,
113
113
  obj=obj,
114
- help=obj.description,
114
+ help=f"{obj.name}" + (f": {obj.description}" if obj.description else ""),
115
115
  deploy_args=self.deploy_args,
116
116
  )
117
117
 
flyte/cli/main.py CHANGED
@@ -144,7 +144,8 @@ def main(
144
144
  initialize_logger(log_level)
145
145
 
146
146
  cfg = config.auto(config_file=config_file)
147
- logger.debug(f"Using config file discovered at location {cfg.source}")
147
+ if cfg.source:
148
+ logger.debug(f"Using config file discovered at location `{cfg.source.absolute()}`")
148
149
 
149
150
  ctx.obj = CLIConfig(
150
151
  log_level=log_level,
@@ -154,7 +155,6 @@ def main(
154
155
  config=cfg,
155
156
  ctx=ctx,
156
157
  )
157
- logger.debug(f"Final materialized Cli config: {ctx.obj}")
158
158
 
159
159
 
160
160
  main.add_command(run)
flyte/errors.py CHANGED
@@ -161,3 +161,12 @@ class RuntimeDataValidationError(RuntimeUserError):
161
161
  super().__init__(
162
162
  "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because {e}"
163
163
  )
164
+
165
+
166
class DeploymentError(RuntimeUserError):
    """
    Raised when deploying a task fails, or when a precondition for deployment is not met.
    """

    def __init__(self, message: str):
        # NOTE(review): RuntimeDataValidationError in this module passes only
        # (code, message) to the base class; confirm what the third positional
        # argument ("user") means for RuntimeUserError.
        error_code = "DeploymentError"
        super().__init__(error_code, message, "user")
@@ -78,9 +78,9 @@ class ContainerTask(TaskTemplate):
78
78
  self._image = image
79
79
  if isinstance(image, str):
80
80
  if image == "auto":
81
- self._image = Image.auto()
81
+ self._image = Image.from_debian_base()
82
82
  else:
83
- self._image = Image.from_prebuilt(image)
83
+ self._image = Image.from_base(image)
84
84
  self._cmd = command
85
85
  self._args = arguments
86
86
  self._input_data_dir = input_data_dir
@@ -27,19 +27,12 @@ else:
27
27
  T = TypeVar("T")
28
28
 
29
29
 
30
- # pr: add back after storage
31
- def get_pandas_storage_options(uri: str, data_config=None, anonymous: bool = False) -> typing.Optional[typing.Dict]:
30
+ def get_pandas_storage_options(uri: str, anonymous: bool = False) -> typing.Optional[typing.Dict]:
32
31
  from pandas.io.common import is_fsspec_url # type: ignore
33
32
 
34
33
  if is_fsspec_url(uri):
35
34
  if uri.startswith("s3"):
36
- # pr: after storage, replace with real call to get_fsspec_storage_options
37
- return {
38
- "cache_regions": True,
39
- "client_kwargs": {"endpoint_url": "http://localhost:30002"},
40
- "key": "minio",
41
- "secret": "miniostorage",
42
- }
35
+ return storage.get_configured_fsspec_kwargs("s3", anonymous=anonymous)
43
36
  return {}
44
37
 
45
38
  # Pandas does not allow storage_options for non-fsspec paths e.g. local.
@@ -70,7 +63,7 @@ class PandasToCSVEncodingHandler(StructuredDatasetEncoder):
70
63
  df.to_csv(
71
64
  path,
72
65
  index=False,
73
- storage_options=get_pandas_storage_options(uri=path, data_config=None),
66
+ storage_options=get_pandas_storage_options(uri=path),
74
67
  )
75
68
  structured_dataset_type.format = CSV
76
69
  return literals_pb2.StructuredDataset(
@@ -87,20 +80,21 @@ class CSVToPandasDecodingHandler(StructuredDatasetDecoder):
87
80
  proto_value: literals_pb2.StructuredDataset,
88
81
  current_task_metadata: literals_pb2.StructuredDatasetMetadata,
89
82
  ) -> "pd.DataFrame":
90
- from botocore.exceptions import NoCredentialsError
91
-
92
83
  uri = proto_value.uri
93
84
  columns = None
94
- kwargs = get_pandas_storage_options(uri=uri, data_config=None)
85
+ kwargs = get_pandas_storage_options(uri=uri)
95
86
  path = os.path.join(uri, ".csv")
96
87
  if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
97
88
  columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
98
89
  try:
99
90
  return pd.read_csv(path, usecols=columns, storage_options=kwargs)
100
- except NoCredentialsError:
101
- logger.debug("S3 source detected, attempting anonymous S3 access")
102
- kwargs = get_pandas_storage_options(uri=uri, data_config=None, anonymous=True)
103
- return pd.read_csv(path, usecols=columns, storage_options=kwargs)
91
+ except Exception as exc:
92
+ if exc.__class__.__name__ == "NoCredentialsError":
93
+ logger.debug("S3 source detected, attempting anonymous S3 access")
94
+ kwargs = get_pandas_storage_options(uri=uri, anonymous=True)
95
+ return pd.read_csv(path, usecols=columns, storage_options=kwargs)
96
+ else:
97
+ raise
104
98
 
105
99
 
106
100
  class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
@@ -128,7 +122,7 @@ class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
128
122
  path,
129
123
  coerce_timestamps="us",
130
124
  allow_truncated_timestamps=False,
131
- storage_options=get_pandas_storage_options(uri=path, data_config=None),
125
+ storage_options=get_pandas_storage_options(uri=path),
132
126
  )
133
127
  structured_dataset_type.format = PARQUET
134
128
  return literals_pb2.StructuredDataset(
@@ -145,19 +139,20 @@ class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
145
139
  flyte_value: literals_pb2.StructuredDataset,
146
140
  current_task_metadata: literals_pb2.StructuredDatasetMetadata,
147
141
  ) -> "pd.DataFrame":
148
- from botocore.exceptions import NoCredentialsError
149
-
150
142
  uri = flyte_value.uri
151
143
  columns = None
152
- kwargs = get_pandas_storage_options(uri=uri, data_config=None)
144
+ kwargs = get_pandas_storage_options(uri=uri)
153
145
  if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
154
146
  columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
155
147
  try:
156
148
  return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
157
- except NoCredentialsError:
158
- logger.debug("S3 source detected, attempting anonymous S3 access")
159
- kwargs = get_pandas_storage_options(uri=uri, data_config=None, anonymous=True)
160
- return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
149
+ except Exception as exc:
150
+ if exc.__class__.__name__ == "NoCredentialsError":
151
+ logger.debug("S3 source detected, attempting anonymous S3 access")
152
+ kwargs = get_pandas_storage_options(uri=uri, anonymous=True)
153
+ return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
154
+ else:
155
+ raise
161
156
 
162
157
 
163
158
  class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
@@ -199,7 +194,6 @@ class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
199
194
  current_task_metadata: literals_pb2.StructuredDatasetMetadata,
200
195
  ) -> "pa.Table":
201
196
  import pyarrow.parquet as pq
202
- from botocore.exceptions import NoCredentialsError
203
197
 
204
198
  uri = proto_value.uri
205
199
  if not storage.is_remote(uri):
@@ -211,9 +205,11 @@ class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
211
205
  columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
212
206
  try:
213
207
  return pq.read_table(path, columns=columns)
214
- except NoCredentialsError as e:
215
- logger.debug("S3 source detected, attempting anonymous S3 access")
216
- fs = storage.get_underlying_filesystem(path=uri, anonymous=True)
217
- if fs is not None:
218
- return pq.read_table(path, filesystem=fs, columns=columns)
219
- raise e
208
+ except Exception as exc:
209
+ if exc.__class__.__name__ == "NoCredentialsError":
210
+ logger.debug("S3 source detected, attempting anonymous S3 access")
211
+ fs = storage.get_underlying_filesystem(path=uri, anonymous=True)
212
+ if fs is not None:
213
+ return pq.read_table(path, filesystem=fs, columns=columns)
214
+ else:
215
+ raise