flyte 2.0.0b17__py3-none-any.whl → 2.0.0b18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of flyte might be problematic.

@@ -0,0 +1,342 @@
+from __future__ import annotations
+
+import hashlib
+import inspect
+from typing import Any, Iterable, Optional, Protocol, Union, runtime_checkable
+
+
+@runtime_checkable
+class HashMethod(Protocol):
+    def update(self, data: memoryview, /) -> None: ...
+    def result(self) -> str: ...
+
+    # Optional convenience; not required by the writers.
+    def reset(self) -> None: ...
+
+
+class PrecomputedValue(HashMethod):
+    def __init__(self, value: str):
+        self._value = value
+
+    def update(self, data: memoryview, /) -> None: ...
+
+    def result(self) -> str:
+        return self._value
+
+
+class HashlibAccumulator(HashMethod):
+    """
+    Wrap a hashlib-like object to the Accumulator protocol.
+    h = hashlib.new("sha256")
+    acc = HashlibAccumulator(h)
+    """
+
+    def __init__(self, h):
+        self._h = h
+
+    def update(self, data: memoryview, /) -> None:
+        self._h.update(data)
+
+    def result(self) -> Any:
+        return self._h.hexdigest()
+
+    @classmethod
+    def from_hash_name(cls, name: str) -> HashlibAccumulator:
+        """
+        Create an accumulator from a hashlib algorithm name.
+        """
+        h = hashlib.new(name)
+        return cls(h)
+
+
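
HashlibAccumulator adapts any hashlib-style object to the HashMethod protocol used by the writers and readers below. A minimal usage sketch (editor's illustration, not part of the diff; it relies only on the classes above and the standard library):

acc = HashlibAccumulator.from_hash_name("sha256")
acc.update(memoryview(b"hello "))
acc.update(memoryview(b"world"))
print(acc.result())  # hex digest of b"hello world"

# PrecomputedValue ignores update() and simply replays a known digest.
pre = PrecomputedValue("abc123")
print(pre.result())  # "abc123"
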
+class HashingWriter:
+    """
+    Sync writer that updates a user-supplied accumulator on every write.
+
+    Hashing covers the exact bytes you pass in. If you write str, it is encoded
+    using the underlying file's .encoding if available, else UTF-8 (for hashing only).
+    """
+
+    def __init__(
+        self,
+        fh,
+        accumulator: HashMethod,
+        *,
+        encoding: Optional[str] = None,
+        errors: str = "strict",
+    ):
+        self._fh = fh
+        self._acc: HashMethod = accumulator
+        self._encoding = encoding or getattr(fh, "encoding", None)
+        self._errors = errors
+
+    def result(self) -> Any:
+        return self._acc.result()
+
+    def _to_bytes_mv(self, data) -> memoryview:
+        if isinstance(data, str):
+            b = data.encode(self._encoding or "utf-8", self._errors)
+            return memoryview(b)
+        if isinstance(data, (bytes, bytearray, memoryview)):
+            return memoryview(data)
+        # Accept any buffer-protocol object (e.g., numpy arrays)
+        return memoryview(data)
+
+    def write(self, data):
+        mv = self._to_bytes_mv(data)
+        self._acc.update(mv)
+        return self._fh.write(data)
+
+    def writelines(self, lines: Iterable[Union[str, bytes, bytearray, memoryview]]):
+        for line in lines:
+            self.write(line)
+
+    def flush(self):
+        return self._fh.flush()
+
+    def close(self):
+        return self._fh.close()
+
+    def __getattr__(self, name):
+        return getattr(self._fh, name)
+
+
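
A short sketch of HashingWriter in use, hashing the exact bytes as they are written (editor's illustration, not part of the diff; the file name is hypothetical):

acc = HashlibAccumulator.from_hash_name("sha256")
with open("payload.bin", "wb") as fh:  # hypothetical local file
    hw = HashingWriter(fh, acc)
    hw.write(b"chunk-1")
    hw.writelines([b"chunk-2", b"chunk-3"])
digest = hw.result()  # sha256 over exactly the bytes passed to write()/writelines()
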
+class AsyncHashingWriter:
+    """
+    Async version of HashingWriter with the same behavior.
+    """
+
+    def __init__(
+        self,
+        fh,
+        accumulator: HashMethod,
+        *,
+        encoding: Optional[str] = None,
+        errors: str = "strict",
+    ):
+        self._fh = fh
+        self._acc = accumulator
+        self._encoding = encoding or getattr(fh, "encoding", None)
+        self._errors = errors
+
+    def result(self) -> Any:
+        return self._acc.result()
+
+    def _to_bytes_mv(self, data) -> memoryview:
+        if isinstance(data, str):
+            b = data.encode(self._encoding or "utf-8", self._errors)
+            return memoryview(b)
+        if isinstance(data, (bytes, bytearray, memoryview)):
+            return memoryview(data)
+        return memoryview(data)
+
+    async def write(self, data):
+        mv = self._to_bytes_mv(data)
+        self._acc.update(mv)
+        return await self._fh.write(data)
+
+    async def writelines(self, lines: Iterable[Union[str, bytes, bytearray, memoryview]]):
+        for line in lines:
+            await self.write(line)
+
+    async def flush(self):
+        fn = getattr(self._fh, "flush", None)
+        if fn is None:
+            return None
+        res = fn()
+        if inspect.isawaitable(res):
+            return await res
+        return res
+
+    async def close(self):
+        fn = getattr(self._fh, "close", None)
+        if fn is None:
+            return None
+        res = fn()
+        if inspect.isawaitable(res):
+            return await res
+        return res
+
+    def __getattr__(self, name):
+        return getattr(self._fh, name)
+
+
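
An async counterpart sketch (editor's illustration, not part of the diff; aiofiles is an assumed stand-in for any handle with an awaitable write()):

import aiofiles  # assumed third-party dependency, used for the example only

async def write_and_hash(path: str) -> str:
    acc = HashlibAccumulator.from_hash_name("sha256")
    async with aiofiles.open(path, "wb") as fh:
        w = AsyncHashingWriter(fh, acc)
        await w.write(b"hello")
        await w.writelines([b" ", b"world"])
    return w.result()  # digest of b"hello world"
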
+class HashingReader:
+    """
+    Sync reader that updates a user-supplied accumulator on every read operation.
+
+    If the underlying handle returns str (text mode), we encode it for hashing only,
+    using the handle's .encoding if present, else the explicit 'encoding' arg, else UTF-8.
+    """
+
+    def __init__(
+        self,
+        fh,
+        accumulator: HashMethod,
+        *,
+        encoding: Optional[str] = None,
+        errors: str = "strict",
+    ):
+        self._fh = fh
+        self._acc = accumulator
+        self._encoding = encoding or getattr(fh, "encoding", None)
+        self._errors = errors
+
+    def result(self) -> str:
+        return self._acc.result()
+
+    def _to_bytes_mv(self, data) -> Optional[memoryview]:
+        if data is None:
+            return None
+        if isinstance(data, str):
+            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
+        if isinstance(data, (bytes, bytearray, memoryview)):
+            return memoryview(data)
+        # Accept any buffer-protocol object (rare for read paths, but safe)
+        return memoryview(data)
+
+    def read(self, size: int = -1):
+        data = self._fh.read(size)
+        mv = self._to_bytes_mv(data)
+        if mv is not None and len(mv) > 0:
+            self._acc.update(mv)
+        return data
+
+    def readline(self, size: int = -1):
+        line = self._fh.readline(size)
+        mv = self._to_bytes_mv(line)
+        if mv is not None and len(mv) > 0:
+            self._acc.update(mv)
+        return line
+
+    def readlines(self, hint: int = -1):
+        lines = self._fh.readlines(hint)
+        # Update in order to reflect exact concatenation
+        for line in lines:
+            mv = self._to_bytes_mv(line)
+            if mv is not None and len(mv) > 0:
+                self._acc.update(mv)
+        return lines
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # Delegate to the underlying iterator to preserve semantics (including buffering),
+        # but intercept the produced line to hash it.
+        line = next(self._fh)
+        mv = self._to_bytes_mv(line)
+        if mv is not None and len(mv) > 0:
+            self._acc.update(mv)
+        return line
+
+    # ---- passthrough ----
+    def __getattr__(self, name):
+        # Avoid leaking private lookups to the underlying object if we're missing something internal
+        if name.startswith("_"):
+            raise AttributeError(name)
+        return getattr(self._fh, name)
+
+
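
A sketch of HashingReader computing a digest while a stream is consumed (editor's illustration, not part of the diff; the file name is hypothetical):

acc = HashlibAccumulator.from_hash_name("sha256")
with open("payload.bin", "rb") as fh:  # hypothetical local file
    reader = HashingReader(fh, acc)
    while chunk := reader.read(1 << 20):  # each read() also feeds the accumulator
        pass
observed = reader.result()
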
+class AsyncHashingReader:
+    """
+    Async reader that updates a user-supplied accumulator on every read operation.
+
+    Works with aiofiles/fsspec async handles. `flush`/`close` may be awaitable or sync.
+    """
+
+    def __init__(
+        self,
+        fh,
+        accumulator: HashMethod,
+        *,
+        encoding: Optional[str] = None,
+        errors: str = "strict",
+    ):
+        self._fh = fh
+        self._acc = accumulator
+        self._encoding = encoding or getattr(fh, "encoding", None)
+        self._errors = errors
+
+    def result(self) -> str:
+        return self._acc.result()
+
+    def _to_bytes_mv(self, data) -> Optional[memoryview]:
+        if data is None:
+            return None
+        if isinstance(data, str):
+            return memoryview(data.encode(self._encoding or "utf-8", self._errors))
+        if isinstance(data, (bytes, bytearray, memoryview)):
+            return memoryview(data)
+        return memoryview(data)
+
+    async def read(self, size: int = -1):
+        data = await self._fh.read(size)
+        mv = self._to_bytes_mv(data)
+        if mv is not None and len(mv) > 0:
+            self._acc.update(mv)
+        return data
+
+    async def readline(self, size: int = -1):
+        line = await self._fh.readline(size)
+        mv = self._to_bytes_mv(line)
+        if mv is not None and len(mv) > 0:
+            self._acc.update(mv)
+        return line
+
+    async def readlines(self, hint: int = -1):
+        # Some async filehandles implement readlines(); if not, fall back to manual loop.
+        if hasattr(self._fh, "readlines"):
+            lines = await self._fh.readlines(hint)
+            for line in lines:
+                mv = self._to_bytes_mv(line)
+                if mv is not None and len(mv) > 0:
+                    self._acc.update(mv)
+            return lines
+        # Fallback: read all via iteration
+        lines = []
+        async for line in self:
+            lines.append(line)
+        return lines
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        # Prefer the underlying async iterator if present
+        anext_fn = getattr(self._fh, "__anext__", None)
+        if anext_fn is not None:
+            try:
+                line = await anext_fn()
+            except StopAsyncIteration:
+                raise
+            mv = self._to_bytes_mv(line)
+            if mv is not None and len(mv) > 0:
+                self._acc.update(mv)
+            return line
+
+        # Fallback to readline-based iteration
+        line = await self.readline()
+        if line == "" or line == b"":
+            raise StopAsyncIteration
+        return line
+
+    async def flush(self):
+        fn = getattr(self._fh, "flush", None)
+        if fn is None:
+            return None
+        res = fn()
+        return await res if inspect.isawaitable(res) else res
+
+    async def close(self):
+        fn = getattr(self._fh, "close", None)
+        if fn is None:
+            return None
+        res = fn()
+        return await res if inspect.isawaitable(res) else res
+
+    # ---- passthrough ----
+    def __getattr__(self, name):
+        if name.startswith("_"):
+            raise AttributeError(name)
+        return getattr(self._fh, name)
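
A closing sketch for the async reader, hashing a text file while iterating its lines (editor's illustration, not part of the diff; aiofiles is an assumed stand-in for any async handle that supports async iteration or readline()):

import aiofiles  # assumed third-party dependency, used for the example only

async def read_and_hash(path: str) -> tuple[list[str], str]:
    acc = HashlibAccumulator.from_hash_name("sha256")
    lines: list[str] = []
    async with aiofiles.open(path, "r", encoding="utf-8") as fh:
        r = AsyncHashingReader(fh, acc)
        async for line in r:  # decoded lines are re-encoded (handle encoding or UTF-8) for hashing only
            lines.append(line)
    return lines, r.result()
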
flyte/models.py CHANGED
@@ -342,9 +342,15 @@ class NativeInterface:
         """
         # Convert positional arguments to keyword arguments
         if len(args) > len(self.inputs):
-            raise ValueError(f"Too many positional arguments provided, inputs {self.inputs.keys()}, args {len(args)}")
+            raise ValueError(
+                f"Too many positional arguments provided, expected inputs {self.inputs.keys()}, args {len(args)}"
+            )
         for arg, input_name in zip(args, self.inputs.keys()):
             kwargs[input_name] = arg
+        if len(kwargs) > len(self.inputs):
+            raise ValueError(
+                f"Too many keyword arguments provided, expected inputs {self.inputs.keys()}, args {kwargs.keys()}"
+            )
         return kwargs
 
     def get_input_types(self) -> Dict[str, Type]:
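
The added check catches calls that supply more named inputs than the interface declares, after positional arguments have been folded into kwargs. A standalone sketch of the counting logic (editor's illustration; the dicts stand in for self.inputs and the caller's arguments):

inputs = {"x": int}            # stand-in for self.inputs
args, kwargs = (1,), {"y": 2}  # one positional plus an unexpected keyword

if len(args) > len(inputs):
    raise ValueError("Too many positional arguments provided ...")
for arg, input_name in zip(args, inputs.keys()):
    kwargs[input_name] = arg   # kwargs is now {"y": 2, "x": 1}
if len(kwargs) > len(inputs):  # 2 > 1: the stray keyword is now rejected
    raise ValueError("Too many keyword arguments provided ...")
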
@@ -410,13 +416,15 @@ class SerializationContext:
     code_bundle: Optional[CodeBundle] = None
     input_path: str = "{{.input}}"
     output_path: str = "{{.outputPrefix}}"
-    _entrypoint_path: str = field(default="_bin/runtime.py", init=False)
+    interpreter_path: str = "/opt/venv/bin/python"
     image_cache: ImageCache | None = None
     root_dir: Optional[pathlib.Path] = None
 
-    def get_entrypoint_path(self, interpreter_path: str) -> str:
+    def get_entrypoint_path(self, interpreter_path: Optional[str] = None) -> str:
         """
         Get the entrypoint path for the task. This is used to determine the entrypoint for the task execution.
         :param interpreter_path: The path to the interpreter (python)
         """
-        return os.path.join(os.path.dirname(os.path.dirname(interpreter_path)), self._entrypoint_path)
+        if interpreter_path is None:
+            interpreter_path = self.interpreter_path
+        return os.path.join(os.path.dirname(interpreter_path), "runtime.py")
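
With the new interpreter_path default, the entrypoint script is resolved next to the interpreter rather than under a sibling _bin directory. A before/after sketch (editor's illustration):

import os

interpreter = "/opt/venv/bin/python"  # new SerializationContext.interpreter_path default
# removed behavior: dirname(dirname(interpreter)) joined with "_bin/runtime.py"
old = os.path.join(os.path.dirname(os.path.dirname(interpreter)), "_bin/runtime.py")
# new behavior: runtime.py sits beside the interpreter
new = os.path.join(os.path.dirname(interpreter), "runtime.py")
print(old)  # /opt/venv/_bin/runtime.py
print(new)  # /opt/venv/bin/runtime.py
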
flyte/remote/_action.py CHANGED
@@ -327,9 +327,10 @@ class Action(ToJSONMixin):
             )
         else:
             details = await self.details()
+            error_message = details.error_info.message if details.error_info else ""
             console.print(
                 f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
-                f" exited unsuccessfully in state {self.phase} with error: {details.error_info}[/bold red]"
+                f" exited unsuccessfully in state {self.phase} with error: {error_message}[/bold red]"
             )
             return
 
@@ -371,9 +372,10 @@ class Action(ToJSONMixin):
                 if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
                     console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
                 else:
+                    error_message = ad.error_info.message if ad.error_info else ""
                     console.print(
                         f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
-                        f" with error: {ad.error_info}[/bold red]"
+                        f" with error: {error_message}[/bold red]"
                     )
                     break
             except asyncio.CancelledError:
flyte/storage/_storage.py CHANGED
@@ -14,6 +14,7 @@ from obstore.fsspec import register
 
 from flyte._initialize import get_storage
 from flyte._logging import logger
+from flyte.errors import InitializationError
 
 _OBSTORE_SUPPORTED_PROTOCOLS = ["s3", "gs", "abfs", "abfss"]
 
@@ -77,21 +78,36 @@ def get_configured_fsspec_kwargs(
     protocol: typing.Optional[str] = None, anonymous: bool = False
 ) -> typing.Dict[str, typing.Any]:
     if protocol:
+        # Try to get storage config safely - may not be initialized for local operations
+        try:
+            storage_config = get_storage()
+        except InitializationError:
+            storage_config = None
+
         match protocol:
             case "s3":
                 # If the protocol is s3, we can use the s3 filesystem
                 from flyte.storage import S3
 
+                if storage_config and isinstance(storage_config, S3):
+                    return storage_config.get_fsspec_kwargs(anonymous=anonymous)
+
                 return S3.auto().get_fsspec_kwargs(anonymous=anonymous)
             case "gs":
                 # If the protocol is gs, we can use the gs filesystem
                 from flyte.storage import GCS
 
+                if storage_config and isinstance(storage_config, GCS):
+                    return storage_config.get_fsspec_kwargs(anonymous=anonymous)
+
                 return GCS.auto().get_fsspec_kwargs(anonymous=anonymous)
             case "abfs" | "abfss":
                 # If the protocol is abfs or abfss, we can use the abfs filesystem
                 from flyte.storage import ABFS
 
+                if storage_config and isinstance(storage_config, ABFS):
+                    return storage_config.get_fsspec_kwargs(anonymous=anonymous)
+
                 return ABFS.auto().get_fsspec_kwargs(anonymous=anonymous)
             case _:
                 return {}
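
In effect, get_configured_fsspec_kwargs now prefers an explicitly configured storage object and only falls back to auto-detection; an uninitialized client no longer fails, because InitializationError is caught. A condensed sketch of the resulting lookup order for s3 (editor's illustration, inferred from the hunk above):

# 1. get_storage()  -> on InitializationError, treat the config as None
# 2. if the config is an S3 instance, use config.get_fsspec_kwargs(anonymous=anonymous)
# 3. otherwise fall back to S3.auto().get_fsspec_kwargs(anonymous=anonymous)
kwargs = get_configured_fsspec_kwargs(protocol="s3")
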
@@ -145,7 +161,6 @@ async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recu
     else:
         exists = file_system.exists(from_path)
         if not exists:
-            # TODO: update exception to be more specific
             raise AssertionError(f"Unable to load data from {from_path}")
         file_system = _get_anonymous_filesystem(from_path)
         logger.debug(f"Attempting anonymous get with {file_system}")
@@ -39,7 +39,6 @@ from pydantic import BaseModel
 from typing_extensions import Annotated, get_args, get_origin
 
 import flyte.storage as storage
-from flyte._hash import HashMethod
 from flyte._logging import logger
 from flyte._utils.helpers import load_proto_from_file
 from flyte.models import NativeInterface
@@ -1096,23 +1095,6 @@ class TypeEngine(typing.Generic[T]):
         ):
             raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
 
-    @classmethod
-    def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[str]:
-        # In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
-        hsh = None
-        if is_annotated(python_type):
-            # We are now dealing with one of two cases:
-            # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
-            #    the method indicated in the annotation.
-            # 2. The annotated type is being used for a different purpose other than calculating hash values,
-            #    in which case we should just continue.
-            for annotation in get_args(python_type)[1:]:
-                if not isinstance(annotation, HashMethod):
-                    continue
-                hsh = annotation.calculate(python_val)
-                break
-        return hsh
-
     @classmethod
     async def to_literal(
         cls, python_val: typing.Any, python_type: Type[T], expected: types_pb2.LiteralType
@@ -1125,8 +1107,6 @@ class TypeEngine(typing.Generic[T]):
         lv = await transformer.to_literal(python_val, python_type, expected)
 
         modify_literal_uris(lv)
-        calculated_hash = cls.calculate_hash(python_val, python_type) or ""
-        lv.hash = calculated_hash
         return lv
 
     @classmethod
@@ -1795,7 +1775,6 @@ class DictTransformer(TypeTransformer[dict]):
         for k, v in python_val.items():
             if type(k) is not str:
                 raise ValueError("Flyte MapType expects all keys to be strings")
-            # TODO: log a warning for Annotated objects that contain HashMethod
 
             _, v_type = self.extract_types(python_type)
             lit_map[k] = TypeEngine.to_literal(v, cast(type, v_type), expected.map_value_type)
@@ -101,6 +101,7 @@ def main(
     from flyte._logging import logger
     from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
 
+    logger.warning(f"Flyte runtime started for action {name} with run name {run_name}")
     logger.info("Registering faulthandler for SIGUSR1")
    faulthandler.register(signal.SIGUSR1)
 
@@ -168,6 +169,7 @@ def main(
         await controller.stop()
 
     asyncio.run(_run_and_stop())
+    logger.warning(f"Flyte runtime completed for action {name} with run name {run_name}")
 
 
 if __name__ == "__main__":
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flyte
-Version: 2.0.0b17
+Version: 2.0.0b18
 Summary: Add your description here
 Author-email: Ketan Umare <kumare3@users.noreply.github.com>
 Requires-Python: >=3.10