flyte 2.0.0b17-py3-none-any.whl → 2.0.0b18-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- flyte/_bin/runtime.py +2 -0
- flyte/_debug/vscode.py +4 -2
- flyte/_deploy.py +3 -1
- flyte/_environment.py +15 -6
- flyte/_hash.py +1 -16
- flyte/_image.py +6 -1
- flyte/_initialize.py +14 -15
- flyte/_internal/controllers/remote/_controller.py +5 -8
- flyte/_internal/controllers/remote/_core.py +1 -1
- flyte/_internal/imagebuild/docker_builder.py +31 -23
- flyte/_internal/imagebuild/remote_builder.py +37 -10
- flyte/_internal/imagebuild/utils.py +2 -1
- flyte/_internal/runtime/convert.py +69 -2
- flyte/_internal/runtime/taskrunner.py +4 -1
- flyte/_logging.py +110 -26
- flyte/_pod.py +2 -1
- flyte/_run.py +6 -1
- flyte/_task_environment.py +5 -1
- flyte/_trace.py +5 -0
- flyte/_version.py +3 -3
- flyte/io/_dir.py +35 -4
- flyte/io/_file.py +61 -15
- flyte/io/_hashing_io.py +342 -0
- flyte/models.py +12 -4
- flyte/remote/_action.py +4 -2
- flyte/storage/_storage.py +16 -1
- flyte/types/_type_engine.py +0 -21
- {flyte-2.0.0b17.data → flyte-2.0.0b18.data}/scripts/runtime.py +2 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/METADATA +1 -1
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/RECORD +35 -34
- {flyte-2.0.0b17.data → flyte-2.0.0b18.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b17.dist-info → flyte-2.0.0b18.dist-info}/top_level.txt +0 -0
flyte/_logging.py
CHANGED
@@ -75,39 +75,52 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
     return handler


-def get_default_handler(log_level: int) -> logging.Handler:
-    handler = logging.StreamHandler()
-    handler.setLevel(log_level)
-    formatter = logging.Formatter(fmt="[%(name)s] %(message)s")
-    if log_format_from_env() == "json":
-        pass
-        # formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
-    handler.setFormatter(formatter)
-    return handler
-
-
 def initialize_logger(log_level: int = get_env_log_level(), enable_rich: bool = False):
     """
     Initializes the global loggers to the default configuration.
+    When enable_rich=True, upgrades to Rich handler for local CLI usage.
     """
     global logger  # noqa: PLW0603
-    logger = _create_logger("flyte", log_level, enable_rich)

+    # Clear existing handlers to reconfigure
+    root = logging.getLogger()
+    root.handlers.clear()

-
-
-
-
-
-
-
-
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.handlers.clear()
+
+    # Set up root logger handler
+    root_handler = None
+    if enable_rich:
+        root_handler = get_rich_handler(log_level)
+
+    if root_handler is None:
+        root_handler = logging.StreamHandler()
+
+    # Add context filter to root handler for all logging
+    root_handler.addFilter(ContextFilter())
+    root.addHandler(root_handler)
+
+    # Set up Flyte logger handler
+    flyte_handler = None
     if enable_rich:
-
-
-
-
-
+        flyte_handler = get_rich_handler(log_level)
+
+    if flyte_handler is None:
+        flyte_handler = logging.StreamHandler()
+    flyte_handler.setLevel(log_level)
+    formatter = logging.Formatter(fmt="%(message)s")
+    flyte_handler.setFormatter(formatter)
+
+    # Add both filters to Flyte handler
+    flyte_handler.addFilter(FlyteInternalFilter())
+    flyte_handler.addFilter(ContextFilter())
+
+    flyte_logger.addHandler(flyte_handler)
+    flyte_logger.setLevel(log_level)
+    flyte_logger.propagate = False  # Prevent double logging
+
+    logger = flyte_logger


 def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):

@@ -135,4 +148,75 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
     return decorator(fn)


-
+class ContextFilter(logging.Filter):
+    """
+    A logging filter that adds the current action's run name and name to all log records.
+    Applied globally to capture context for both user and Flyte internal logging.
+    """
+
+    def filter(self, record):
+        from flyte._context import ctx
+
+        c = ctx()
+        if c:
+            action = c.action
+            record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
+        return True
+
+
+class FlyteInternalFilter(logging.Filter):
+    """
+    A logging filter that adds [flyte] prefix to internal Flyte logging only.
+    """
+
+    def filter(self, record):
+        if record.name.startswith("flyte"):
+            record.msg = f"[flyte] {record.msg}"
+        return True
+
+
+def _setup_root_logger():
+    """
+    Configure the root logger to capture all logging with context information.
+    This ensures both user code and Flyte internal logging get the context.
+    """
+    root = logging.getLogger()
+    root.handlers.clear()  # Remove any existing handlers to prevent double logging
+
+    # Create a basic handler for the root logger
+    handler = logging.StreamHandler()
+    # Add context filter to ALL logging
+    handler.addFilter(ContextFilter())
+
+    # Simple formatter since filters handle prefixes
+    root.addHandler(handler)
+
+
+def _create_flyte_logger() -> logging.Logger:
+    """
+    Create the internal Flyte logger with [flyte] prefix.
+    """
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.setLevel(get_env_log_level())
+
+    # Add a handler specifically for flyte logging with the prefix filter
+    handler = logging.StreamHandler()
+    handler.setLevel(get_env_log_level())
+    handler.addFilter(FlyteInternalFilter())
+    handler.addFilter(ContextFilter())
+
+    formatter = logging.Formatter(fmt="%(message)s")
+    handler.setFormatter(formatter)
+
+    # Prevent propagation to root to avoid double logging
+    flyte_logger.propagate = False
+    flyte_logger.addHandler(handler)
+
+    return flyte_logger
+
+
+# Initialize root logger for global context
+_setup_root_logger()
+
+# Create the Flyte internal logger
+logger = _create_flyte_logger()
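The net effect of the rewrite: the root logger gets a plain StreamHandler carrying ContextFilter (so user-code records pick up the [run_name][action] prefix), while the "flyte" logger gets its own non-propagating handler carrying both FlyteInternalFilter and ContextFilter. A minimal, self-contained sketch of how the prefix filter behaves (the filter body is copied from the diff; the logger name and message are illustrative):

```python
import logging

class FlyteInternalFilter(logging.Filter):
    """Prefixes records from the 'flyte' logger hierarchy, as in the diff above."""

    def filter(self, record):
        if record.name.startswith("flyte"):
            record.msg = f"[flyte] {record.msg}"
        return True

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(message)s"))
handler.addFilter(FlyteInternalFilter())

log = logging.getLogger("flyte.demo")
log.setLevel(logging.INFO)
log.addHandler(handler)
log.propagate = False  # as in the diff: prevents double logging via root

log.info("image build started")  # prints: [flyte] image build started
```

Because the "flyte" logger no longer propagates, each record passes through ContextFilter exactly once even though the filter is installed on both handlers.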
flyte/_pod.py
CHANGED
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, Optional

 if TYPE_CHECKING:
     from flyteidl.core.tasks_pb2 import K8sPod
-    from kubernetes.client import
+    from kubernetes.client import V1PodSpec


 _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"

@@ -21,6 +21,7 @@ class PodTemplate(object):

     def to_k8s_pod(self) -> "K8sPod":
         from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
+        from kubernetes.client import ApiClient

         return K8sPod(
             metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
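The change narrows the TYPE_CHECKING import to just V1PodSpec and moves the ApiClient import into to_k8s_pod(), so kubernetes is only imported when a pod spec is actually serialized. A sketch of that deferred-import pattern (the class here is illustrative, not Flyte's PodTemplate):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Type-only import: erased at runtime, so merely importing this module
    # never requires the kubernetes package.
    from kubernetes.client import V1PodSpec


class PodLike:
    def __init__(self, pod_spec: Optional["V1PodSpec"] = None) -> None:
        self.pod_spec = pod_spec

    def to_dict(self) -> dict:
        # Runtime import deferred to the call site, mirroring to_k8s_pod().
        from kubernetes.client import ApiClient

        return ApiClient().sanitize_for_serialization(self.pod_spec) or {}
```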
flyte/_run.py
CHANGED
@@ -161,7 +161,12 @@ class _Runner:
             code_bundle = cached_value.code_bundle
             image_cache = cached_value.image_cache
         else:
-
+            if not self._dry_run:
+                image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
+            else:
+                from ._internal.imagebuild.image_builder import ImageCache
+
+                image_cache = ImageCache(image_lookup={})

         if self._interactive_mode:
             code_bundle = await build_pkl_bundle(
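A dry run now skips image builds entirely, substituting an ImageCache with an empty lookup instead of awaiting build_images.aio(...). A reduced sketch of the branch (ImageCache here is a stand-in dataclass; the real one lives in flyte._internal.imagebuild.image_builder):

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class ImageCache:
    # The diff constructs the real class the same way: ImageCache(image_lookup={})
    image_lookup: Dict[str, str] = field(default_factory=dict)

async def build_images() -> ImageCache:
    # Placeholder for the expensive build_images.aio(...) call in the diff.
    return ImageCache(image_lookup={"env": "registry/image:tag"})

async def resolve_images(dry_run: bool) -> ImageCache:
    # Mirrors the new branch in _Runner: no builds on a dry run.
    if dry_run:
        return ImageCache(image_lookup={})
    return await build_images()
```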
flyte/_task_environment.py
CHANGED
@@ -18,7 +18,7 @@ from typing import (

 import rich.repr

-from ._cache import CacheRequest
+from ._cache import Cache, CacheRequest
 from ._doc import Documentation
 from ._environment import Environment
 from ._image import Image

@@ -74,6 +74,10 @@ class TaskEnvironment(Environment):
         super().__post_init__()
         if self.reusable is not None and self.plugin_config is not None:
             raise ValueError("Cannot set plugin_config when environment is reusable.")
+        if self.reusable and not isinstance(self.reusable, ReusePolicy):
+            raise TypeError(f"Expected reusable to be of type ReusePolicy, got {type(self.reusable)}")
+        if self.cache and not isinstance(self.cache, (str, Cache)):
+            raise TypeError(f"Expected cache to be of type str or Cache, got {type(self.cache)}")

     def clone_with(
         self,
flyte/_trace.py
CHANGED
@@ -3,6 +3,7 @@ import inspect
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast

+from flyte._logging import logger
 from flyte.models import NativeInterface

 T = TypeVar("T")

@@ -33,10 +34,13 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
         iface = NativeInterface.from_callable(func)
         info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
         if ok:
+            logger.info(f"Found existing trace info for {func}, {info}")
             if info.output:
                 return info.output
             elif info.error:
                 raise info.error
+        else:
+            logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
         start_time = time.time()
         try:
             # Cast to Awaitable to satisfy mypy

@@ -44,6 +48,7 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
             results = await coroutine_result
             info.add_outputs(results, start_time=start_time, end_time=time.time())
             await controller.record_trace(info)
+            logger.debug(f"Finished trace for {func}, {info}")
             return results
         except Exception as e:
             # If there is an error, we need to record it
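The added log lines make the replay path of traced functions observable: when a recorded action is found, the stored output (or error) is returned without re-running the function. Hypothetical usage, assuming the decorator is re-exported as flyte.trace:

```python
import flyte

@flyte.trace
async def fetch(url: str) -> str:
    ...

# First run: fetch executes, the trace is recorded, and
#   "Finished trace for <fetch>, ..." is logged at DEBUG.
# On recovery: get_action_outputs() finds the record, logs
#   "Found existing trace info for <fetch>, ..." at INFO, and returns the
#   stored output (or re-raises the stored error) without calling fetch.
```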
flyte/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '2.0.
-__version_tuple__ = version_tuple = (2, 0, 0, '
+__version__ = version = '2.0.0b18'
+__version_tuple__ = version_tuple = (2, 0, 0, 'b18')

-__commit_id__ = commit_id = '
+__commit_id__ = commit_id = 'g930faeaea'
flyte/io/_dir.py
CHANGED
@@ -48,6 +48,7 @@ class Dir(BaseModel, Generic[T], SerializableType):
     path: str
     name: Optional[str] = None
     format: str = ""
+    hash: Optional[str] = None

     class Config:
         arbitrary_types_allowed = True

@@ -248,13 +249,20 @@ class Dir(BaseModel, Generic[T], SerializableType):
             raise NotImplementedError("Sync download is not implemented for remote paths")

     @classmethod
-    async def from_local(
+    async def from_local(
+        cls,
+        local_path: Union[str, Path],
+        remote_path: Optional[str] = None,
+        dir_cache_key: Optional[str] = None,
+    ) -> Dir[T]:
         """
         Asynchronously create a new Dir by uploading a local directory to the configured remote store.

         Args:
             local_path: Path to the local directory
             remote_path: Optional path to store the directory remotely. If None, a path will be generated.
+            dir_cache_key: If you have a precomputed hash value you want to use when computing cache keys for
+                discoverable tasks that this File is an input to.

         Returns:
             A new Dir instance pointing to the uploaded directory

@@ -262,13 +270,34 @@ class Dir(BaseModel, Generic[T], SerializableType):
         Example:
         ```python
         remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/')
+        # With a known hash value you want to use for cache key calculation
+        remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/', dir_cache_key='abc123')
         ```
         """
         local_path_str = str(local_path)
         dirname = os.path.basename(os.path.normpath(local_path_str))

         output_path = await storage.put(from_path=local_path_str, to_path=remote_path, recursive=True)
-        return cls(path=output_path, name=dirname)
+        return cls(path=output_path, name=dirname, hash=dir_cache_key)
+
+    @classmethod
+    def from_existing_remote(cls, remote_path: str, dir_cache_key: Optional[str] = None) -> Dir[T]:
+        """
+        Create a Dir reference from an existing remote directory.
+
+        Args:
+            remote_path: The remote path to the existing directory
+            dir_cache_key: Optional hash value to use for cache key computation. If not specified,
+                the cache key will be computed based on this object's attributes.
+
+        Example:
+        ```python
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/")
+        # With a known hash
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
+        ```
+        """
+        return cls(path=remote_path, hash=dir_cache_key)

     @classmethod
     def from_local_sync(cls, local_path: Union[str, Path], remote_path: Optional[str] = None) -> Dir[T]:

@@ -414,7 +443,8 @@ class DirTransformer(TypeTransformer[Dir]):
                 ),
                 uri=python_val.path,
             )
-        )
+            ),
+            hash=python_val.hash if python_val.hash else None,
         )

@@ -432,7 +462,8 @@ class DirTransformer(TypeTransformer[Dir]):

         uri = lv.scalar.blob.uri
         filename = Path(uri).name
-
+        hash_value = lv.hash if lv.hash else None
+        f: Dir = Dir(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format, hash=hash_value)
         return f

     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[Dir]:
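As with File below, the new hash travels on the literal itself (DirTransformer now sets the literal's hash on serialization and reads lv.hash back), which is what task cache-key computation consumes. Hypothetical usage of the two entry points, assuming Dir is exported from flyte.io:

```python
from flyte.io import Dir

# Upload a local directory and attach a precomputed cache key:
d1 = await Dir.from_local("/tmp/data_dir/", "s3://bucket/data/", dir_cache_key="abc123")

# Reference a directory that already lives remotely:
d2 = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")

# Without dir_cache_key, cache keys fall back to hashing the Dir's
# attributes (path, name, format), not the directory contents.
```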
flyte/io/_file.py
CHANGED
@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager, contextmanager
 from pathlib import Path
 from typing import (
     IO,
+    Annotated,
     Any,
     AsyncGenerator,
     Dict,

@@ -21,12 +22,14 @@ from flyteidl.core import literals_pb2, types_pb2
 from fsspec.asyn import AsyncFileSystem
 from fsspec.utils import get_protocol
 from mashumaro.types import SerializableType
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, Field, model_validator
+from pydantic.json_schema import SkipJsonSchema

 import flyte.storage as storage
 from flyte._context import internal_ctx
 from flyte._initialize import requires_initialization
 from flyte._logging import logger
+from flyte.io._hashing_io import AsyncHashingReader, HashingWriter, HashMethod, PrecomputedValue
 from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError

 # Type variable for the file format

@@ -104,6 +107,8 @@ class File(BaseModel, Generic[T], SerializableType):
     path: str
     name: Optional[str] = None
     format: str = ""
+    hash: Optional[str] = None
+    hash_method: Annotated[Optional[HashMethod], Field(default=None, exclude=True), SkipJsonSchema()] = None

     class Config:
         arbitrary_types_allowed = True

@@ -139,7 +144,7 @@ class File(BaseModel, Generic[T], SerializableType):

     @classmethod
     @requires_initialization
-    def new_remote(cls) -> File[T]:
+    def new_remote(cls, hash_method: Optional[HashMethod | str] = None) -> File[T]:
         """
         Create a new File reference for a remote file that will be written to.


@@ -155,11 +160,13 @@ class File(BaseModel, Generic[T], SerializableType):
         ```
         """
         ctx = internal_ctx()
+        known_cache_key = hash_method if isinstance(hash_method, str) else None
+        method = hash_method if isinstance(hash_method, HashMethod) else None

-        return cls(path=ctx.raw_data.get_random_remote_path())
+        return cls(path=ctx.raw_data.get_random_remote_path(), hash=known_cache_key, hash_method=method)

     @classmethod
-    def from_existing_remote(cls, remote_path: str) -> File[T]:
+    def from_existing_remote(cls, remote_path: str, file_cache_key: Optional[str] = None) -> File[T]:
         """
         Create a File reference from an existing remote file.


@@ -172,8 +179,10 @@ class File(BaseModel, Generic[T], SerializableType):

         Args:
             remote_path: The remote path to the existing file
+            file_cache_key: Optional hash value to use for discovery purposes. If not specified, the value of this
+                File object will be hashed (basically the path, not the contents).
         """
-        return cls(path=remote_path)
+        return cls(path=remote_path, hash=file_cache_key)

     @asynccontextmanager
     async def open(

@@ -184,7 +193,7 @@ class File(BaseModel, Generic[T], SerializableType):
         cache_options: Optional[dict] = None,
         compression: Optional[str] = None,
         **kwargs,
-    ) -> AsyncGenerator[IO[Any]]:
+    ) -> AsyncGenerator[Union[IO[Any], "HashingWriter"], None]:
         """
         Asynchronously open the file and return a file-like object.


@@ -245,7 +254,15 @@ class File(BaseModel, Generic[T], SerializableType):
                 file_handle.close()

         with fs.open(self.path, mode) as file_handle:
-
+            if self.hash_method and self.hash is None:
+                logger.debug(f"Wrapping file handle with hashing writer using {self.hash_method}")
+                fh = HashingWriter(file_handle, accumulator=self.hash_method)
+                yield fh
+                self.hash = fh.result()
+                fh.close()
+            else:
+                yield file_handle
+                file_handle.close()

     def exists_sync(self) -> bool:
         """

@@ -351,13 +368,22 @@ class File(BaseModel, Generic[T], SerializableType):

     @classmethod
     @requires_initialization
-    async def from_local(
+    async def from_local(
+        cls,
+        local_path: Union[str, Path],
+        remote_destination: Optional[str] = None,
+        hash_method: Optional[HashMethod | str] = None,
+    ) -> File[T]:
         """
         Create a new File object from a local file that will be uploaded to the configured remote store.

         Args:
             local_path: Path to the local file
             remote_destination: Optional path to store the file remotely. If None, a path will be generated.
+            hash_method: Pass this argument either as a set string or a HashMethod to use for
+                determining a task's cache key if this File object is used as an input to said task. If not specified,
+                the cache key will just be computed based on this object's attributes (i.e. path, name, format, etc.).
+                If there is a set value you want to use, please pass an instance of the PrecomputedValue HashMethod.

         Returns:
             A new File instance pointing to the uploaded file

@@ -376,20 +402,38 @@ class File(BaseModel, Generic[T], SerializableType):

         # If remote_destination was not set by the user, and the configured raw data path is also local,
         # then let's optimize by not uploading.
+        hash_value = hash_method if isinstance(hash_method, str) else None
+        hash_method = hash_method if isinstance(hash_method, HashMethod) else None
         if "file" in protocol:
             if remote_destination is None:
                 path = str(Path(local_path).absolute())
             else:
                 # Otherwise, actually make a copy of the file
-                async with aiofiles.open(
-                async with aiofiles.open(
-
+                async with aiofiles.open(local_path, "rb") as src:
+                    async with aiofiles.open(remote_path, "wb") as dst:
+                        if hash_method:
+                            dst_wrapper = HashingWriter(dst, accumulator=hash_method)
+                            await dst_wrapper.write(await src.read())
+                            hash_value = dst_wrapper.result()
+                        else:
+                            await dst.write(await src.read())
                 path = str(Path(remote_path).absolute())
         else:
             # Otherwise upload to remote using async storage layer
-
+            if hash_method:
+                # We can skip the wrapper if the hash method is just a precomputed value
+                if not isinstance(hash_method, PrecomputedValue):
+                    async with aiofiles.open(local_path, "rb") as src:
+                        src_wrapper = AsyncHashingReader(src, accumulator=hash_method)
+                        path = await storage.put_stream(src_wrapper, to_path=remote_path)
+                    hash_value = src_wrapper.result()
+                else:
+                    path = await storage.put(str(local_path), remote_path)
+                    hash_value = hash_method.result()
+            else:
+                path = await storage.put(str(local_path), remote_path)

-        f = cls(path=path, name=filename)
+        f = cls(path=path, name=filename, hash_method=hash_method, hash=hash_value)
         return f


@@ -432,7 +476,8 @@ class FileTransformer(TypeTransformer[File]):
                 ),
                 uri=python_val.path,
             )
-        )
+            ),
+            hash=python_val.hash if python_val.hash else None,
         )

     async def to_python_value(

@@ -450,7 +495,8 @@ class FileTransformer(TypeTransformer[File]):

         uri = lv.scalar.blob.uri
         filename = Path(uri).name
-
+        hash_value = lv.hash if lv.hash else None
+        f: File = File(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format, hash=hash_value)
         return f

     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[File]:
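In summary, hash_method accepts either a plain string, used verbatim as the cache key with no hashing performed, or a HashMethod whose accumulator sees every byte via the HashingWriter/AsyncHashingReader wrappers from the new flyte/io/_hashing_io.py module (PrecomputedValue short-circuits the streaming). A hypothetical sketch, assuming File is exported from flyte.io:

```python
from flyte.io import File

# 1. Plain-string shortcut: the value becomes File.hash (and the literal's
#    hash) directly; nothing is streamed or digested.
f1 = await File.from_local("data.csv", hash_method="dataset-v1")
assert f1.hash == "dataset-v1"

# 2. Referencing an already-uploaded object, key supplied up front:
f2 = File.from_existing_remote("s3://bucket/data.csv", file_cache_key="dataset-v1")

# 3. Passing a HashMethod instead streams the uploaded bytes through its
#    accumulator and stores the digest on File.hash; with new_remote(),
#    HashingWriter fills in File.hash when the open() context manager closes.
```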
|