unienv 0.0.1b4__py3-none-any.whl → 0.0.1b6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/METADATA +3 -2
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/RECORD +43 -32
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/WHEEL +1 -1
- unienv_data/base/common.py +25 -10
- unienv_data/base/storage.py +2 -0
- unienv_data/batches/backend_compat.py +1 -1
- unienv_data/batches/combined_batch.py +1 -1
- unienv_data/batches/slicestack_batch.py +1 -0
- unienv_data/replay_buffer/replay_buffer.py +179 -65
- unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- unienv_data/storages/_episode_storage.py +438 -0
- unienv_data/storages/_list_storage.py +136 -0
- unienv_data/storages/backend_compat.py +268 -0
- unienv_data/storages/dict_storage.py +39 -7
- unienv_data/storages/flattened.py +11 -4
- unienv_data/storages/hdf5.py +11 -0
- unienv_data/storages/image_storage.py +144 -0
- unienv_data/storages/npz_storage.py +135 -0
- unienv_data/storages/pytorch.py +17 -10
- unienv_data/storages/transformation.py +16 -1
- unienv_data/storages/video_storage.py +297 -0
- unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
- unienv_data/transformations/image_compress.py +97 -21
- unienv_interface/func_wrapper/frame_stack.py +1 -1
- unienv_interface/space/space_utils/batch_utils.py +5 -1
- unienv_interface/space/space_utils/flatten_utils.py +8 -2
- unienv_interface/space/spaces/dict.py +6 -0
- unienv_interface/space/spaces/tuple.py +4 -4
- unienv_interface/transformations/__init__.py +3 -1
- unienv_interface/transformations/batch_and_unbatch.py +42 -4
- unienv_interface/transformations/chained_transform.py +9 -8
- unienv_interface/transformations/crop.py +69 -0
- unienv_interface/transformations/dict_transform.py +8 -2
- unienv_interface/transformations/identity.py +16 -0
- unienv_interface/transformations/image_resize.py +106 -0
- unienv_interface/transformations/iter_transform.py +92 -0
- unienv_interface/transformations/rescale.py +24 -5
- unienv_interface/utils/symbol_util.py +7 -1
- unienv_interface/wrapper/backend_compat.py +1 -1
- unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/top_level.txt +0 -0
- /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
|
@@ -0,0 +1,1174 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
# From https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import functools
|
|
10
|
+
|
|
11
|
+
import mmap
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
import sys
|
|
15
|
+
import tempfile
|
|
16
|
+
from multiprocessing import reduction, util
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import (
|
|
19
|
+
Any,
|
|
20
|
+
Callable,
|
|
21
|
+
Iterator,
|
|
22
|
+
List,
|
|
23
|
+
Sequence,
|
|
24
|
+
Tuple,
|
|
25
|
+
TYPE_CHECKING,
|
|
26
|
+
TypeVar,
|
|
27
|
+
Union,
|
|
28
|
+
overload
|
|
29
|
+
)
|
|
30
|
+
import numpy as np
|
|
31
|
+
from pyvers import implement_for
|
|
32
|
+
|
|
33
|
+
import torch
|
|
34
|
+
from torch import Tensor
|
|
35
|
+
from torch.nn.parameter import (
|
|
36
|
+
UninitializedBuffer,
|
|
37
|
+
UninitializedParameter,
|
|
38
|
+
UninitializedTensorMixin,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# Error text raised by this module when torch does not expose the private
# nested-tensor stride/offset helper it looks up with ``getattr``.
NESTED_TENSOR_ERR = (
    "The PyTorch version isn't compatible with nested tensors. "
    "Please upgrade to a more recent version."
)
|
|
46
|
+
|
|
47
|
+
if TYPE_CHECKING:
|
|
48
|
+
from typing import Self
|
|
49
|
+
else:
|
|
50
|
+
Self = Any
|
|
51
|
+
|
|
52
|
+
def _maybe_correct_neg_dim(
|
|
53
|
+
dim: int, shape: torch.Size | None, ndim: int | None = None
|
|
54
|
+
) -> int:
|
|
55
|
+
"""Corrects neg dim to pos."""
|
|
56
|
+
if ndim is None:
|
|
57
|
+
ndim = len(shape)
|
|
58
|
+
if dim < 0:
|
|
59
|
+
new_dim = ndim + dim
|
|
60
|
+
else:
|
|
61
|
+
new_dim = dim
|
|
62
|
+
if new_dim < 0 or new_dim >= ndim:
|
|
63
|
+
if shape is not None:
|
|
64
|
+
raise IndexError(
|
|
65
|
+
f"Incompatible dim {new_dim} for tensordict with shape {shape}."
|
|
66
|
+
)
|
|
67
|
+
raise IndexError(
|
|
68
|
+
f"Incompatible dim {new_dim} for tensordict with batch dims {ndim}."
|
|
69
|
+
)
|
|
70
|
+
return new_dim
|
|
71
|
+
|
|
72
|
+
def _shape(tensor: Tensor, nested_shape=False) -> torch.Size:
|
|
73
|
+
if isinstance(tensor, UninitializedTensorMixin):
|
|
74
|
+
return torch.Size([*getattr(tensor, "batch_size", ()), -1])
|
|
75
|
+
elif not isinstance(tensor, Tensor):
|
|
76
|
+
return tensor.shape
|
|
77
|
+
if tensor.is_nested:
|
|
78
|
+
if nested_shape:
|
|
79
|
+
return tensor._nested_tensor_size()
|
|
80
|
+
shape = []
|
|
81
|
+
for i in range(tensor.ndim):
|
|
82
|
+
try:
|
|
83
|
+
shape.append(tensor.size(i))
|
|
84
|
+
except RuntimeError:
|
|
85
|
+
shape.append(-1)
|
|
86
|
+
return torch.Size(shape)
|
|
87
|
+
return tensor.shape
|
|
88
|
+
|
|
89
|
+
if sys.version_info >= (3, 10):
|
|
90
|
+
_zip_strict = functools.partial(zip, strict=True)
|
|
91
|
+
else:
|
|
92
|
+
def _zip_strict(*iterables):
|
|
93
|
+
iterables = tuple(tuple(it) for it in iterables)
|
|
94
|
+
lengths = {len(it) for it in iterables}
|
|
95
|
+
if len(lengths) > 1:
|
|
96
|
+
raise ValueError("lengths of iterables differ.")
|
|
97
|
+
|
|
98
|
+
return zip(*iterables)
|
|
99
|
+
|
|
100
|
+
# Alias used in this module to annotate ``index`` arguments.
IndexType = Union[
    None,
    int,
    slice,
    str,
    Tensor,
    List[Any],
    Tuple[Any, ...],
]
|
|
101
|
+
|
|
102
|
+
class MemoryMappedTensor(torch.Tensor):
|
|
103
|
+
"""A Memory-mapped Tensor.
|
|
104
|
+
|
|
105
|
+
Supports filenames or file handlers.
|
|
106
|
+
|
|
107
|
+
The main advantage of MemoryMappedTensor resides in its serialization methods,
|
|
108
|
+
which ensure that the tensor is passed through queues or RPC remote calls without
|
|
109
|
+
any copy.
|
|
110
|
+
|
|
111
|
+
.. note::
|
|
112
|
+
When used within RPC settings, the filepath should be accessible to both nodes.
|
|
113
|
+
If it isn't the behaviour of passing a MemoryMappedTensor from one worker
|
|
114
|
+
to another is undefined.
|
|
115
|
+
|
|
116
|
+
MemoryMappedTensor supports multiple construction methods.
|
|
117
|
+
|
|
118
|
+
Examples:
|
|
119
|
+
>>> # from an existing tensor
|
|
120
|
+
>>> tensor = torch.randn(3)
|
|
121
|
+
>>> with tempfile.NamedTemporaryFile() as file:
|
|
122
|
+
... memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
|
|
123
|
+
... assert memmap_tensor.filename is not None
|
|
124
|
+
>>> # if no filename is passed, a handler is used
|
|
125
|
+
>>> tensor = torch.randn(3)
|
|
126
|
+
>>> memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
|
|
127
|
+
>>> assert memmap_tensor._filename is None
|
|
128
|
+
>>> # one can create an empty tensor too
|
|
129
|
+
>>> with tempfile.NamedTemporaryFile() as file:
|
|
130
|
+
... memmap_tensor_empty = MemoryMappedTensor.empty_like(tensor, filename=file.name)
|
|
131
|
+
>>> with tempfile.NamedTemporaryFile() as file:
|
|
132
|
+
... memmap_tensor_zero = MemoryMappedTensor.zeros_like(tensor, filename=file.name)
|
|
133
|
+
>>> with tempfile.NamedTemporaryFile() as file:
|
|
134
|
+
... memmap_tensor = MemoryMappedTensor.ones_like(tensor, filename=file.name)
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
_filename: str | Path = None
|
|
138
|
+
_handler: _FileHandler = None
|
|
139
|
+
_clear: bool
|
|
140
|
+
index: Any
|
|
141
|
+
parent_shape: torch.Size
|
|
142
|
+
|
|
143
|
+
def __new__(
|
|
144
|
+
cls,
|
|
145
|
+
source,
|
|
146
|
+
*,
|
|
147
|
+
dtype=None,
|
|
148
|
+
shape=None,
|
|
149
|
+
index=None,
|
|
150
|
+
device=None,
|
|
151
|
+
handler=None,
|
|
152
|
+
filename=None,
|
|
153
|
+
):
|
|
154
|
+
if device is not None and torch.device(device).type != "cpu":
|
|
155
|
+
raise ValueError(f"{cls} device must be cpu!")
|
|
156
|
+
if isinstance(source, str):
|
|
157
|
+
if filename is not None:
|
|
158
|
+
raise TypeError("Duplicated filename argument.")
|
|
159
|
+
filename = source
|
|
160
|
+
source = None
|
|
161
|
+
if filename is not None:
|
|
162
|
+
if dtype is not None:
|
|
163
|
+
raise TypeError("Cannot pass new dtype if source is provided.")
|
|
164
|
+
result = cls.from_tensor(
|
|
165
|
+
torch.as_tensor(source),
|
|
166
|
+
filename=filename,
|
|
167
|
+
# dtype=dtype,
|
|
168
|
+
shape=shape,
|
|
169
|
+
# index=index,
|
|
170
|
+
)
|
|
171
|
+
if index is not None:
|
|
172
|
+
return result[index]
|
|
173
|
+
return result
|
|
174
|
+
elif isinstance(source, torch.StorageBase):
|
|
175
|
+
return cls.from_storage(
|
|
176
|
+
source,
|
|
177
|
+
dtype=dtype,
|
|
178
|
+
shape=shape,
|
|
179
|
+
index=index,
|
|
180
|
+
device=device,
|
|
181
|
+
handler=handler,
|
|
182
|
+
filename=filename,
|
|
183
|
+
)
|
|
184
|
+
elif handler is not None:
|
|
185
|
+
return cls.from_handler(
|
|
186
|
+
handler,
|
|
187
|
+
dtype,
|
|
188
|
+
shape,
|
|
189
|
+
index,
|
|
190
|
+
)
|
|
191
|
+
return super().__new__(cls, source)
|
|
192
|
+
|
|
193
|
+
def __init__(
|
|
194
|
+
self,
|
|
195
|
+
source,
|
|
196
|
+
*,
|
|
197
|
+
handler=None,
|
|
198
|
+
dtype=None,
|
|
199
|
+
shape=None,
|
|
200
|
+
device=None,
|
|
201
|
+
filename=None,
|
|
202
|
+
): ...
|
|
203
|
+
|
|
204
|
+
# Use torch's ``_disabled_torch_function_impl`` as this class's
# ``__torch_function__`` hook.
# NOTE(review): presumably this opts the subclass out of the
# ``__torch_function__`` override protocol - confirm against the torch docs.
__torch_function__ = torch._C._disabled_torch_function_impl
|
|
205
|
+
|
|
206
|
+
@classmethod
def from_tensor(
    cls,
    input,
    *,
    filename: Path | str = None,
    existsok: bool = False,
    copy_existing: bool = False,
    copy_data: bool = True,
    shape: torch.Size | None = None,
):  # noqa: D417
    """Build a MemoryMappedTensor holding the content of another tensor.

    When ``input`` is already a MemoryMappedTensor it is returned unchanged if
    neither it nor the call specifies a file, or if both point at the same
    absolute path. Otherwise a new :class:`MemoryMappedTensor` is created.

    Args:
        input (torch.Tensor): the tensor whose content is copied onto the
            MemoryMappedTensor.

    Keyword Args:
        filename (path to a file): where the tensor should be stored. When
            omitted, a file handler is used instead.
        existsok (bool, optional): if ``False`` (default), an already existing
            ``filename`` raises an exception.
        copy_existing (bool, optional): must be ``True`` to copy a
            MemoryMappedTensor that already has a file onto a different
            ``filename``; otherwise an exception is raised. This guards
            against inadvertently duplicating data on disk.
        copy_data (bool, optional): if ``True`` (default), the content of
            ``input`` is copied into the new storage.
        shape (torch.Size or torch.Tensor): overrides the tensor shape. A
            tensor value is interpreted as the nested shapes of a nested
            tensor.

    Raises:
        RuntimeError: on a filename conflict without ``copy_existing``, if
            ``input`` requires grad, if ``filename`` exists and ``existsok``
            is ``False``, or if torch lacks nested-tensor support.
        TypeError: if ``input`` is a ``numpy.ndarray``.
        ValueError: for complex dtypes when no ``filename`` is given.
    """
    if isinstance(input, MemoryMappedTensor):
        source_has_file = input._filename is not None
        if filename is None:
            if not source_has_file:
                # no location requested and none held: nothing to do
                return input
        elif source_has_file:
            if Path(filename).absolute() == Path(input.filename).absolute():
                # memmap already lives at the requested location
                return input
            if not copy_existing:
                raise RuntimeError(
                    f"A filename was provided but the tensor already has a file associated "
                    f"({input.filename}). "
                    f"To copy the tensor onto the new location, pass copy_existing=True."
                )
    elif isinstance(input, np.ndarray):
        raise TypeError(
            "Convert input to torch.Tensor before calling MemoryMappedTensor.from_tensor."
        )
    if input.requires_grad:
        raise RuntimeError(
            "MemoryMappedTensor.from_tensor is incompatible with tensor.requires_grad."
        )

    if shape is None:
        shape = _shape(input, nested_shape=True)
    nested = isinstance(shape, torch.Tensor)
    if nested:
        # nested shapes: total element count over all constituents
        shape_numel = shape.prod(-1).sum()
    else:
        as_size = shape if isinstance(shape, torch.Size) else torch.Size(shape)
        shape_numel = as_size.numel()

    if filename is None:
        # Handler-backed storage: compute the byte size from the dtype.
        dtype = input.dtype
        if dtype.is_floating_point:
            size = torch.finfo(dtype).bits // 8 * shape_numel
        elif dtype.is_complex:
            raise ValueError(
                "Complex-valued tensors are not supported by MemoryMappedTensor."
            )
        elif dtype == torch.bool:
            size = shape_numel
        else:
            # assume integer
            size = torch.iinfo(dtype).bits // 8 * shape_numel
        handler = _FileHandler(size)
        if nested:
            compute_layout = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if compute_layout is None:
                raise RuntimeError(NESTED_TENSOR_ERR)
            offsets_strides = compute_layout(shape)
            result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
            if copy_data:
                result.untyped_storage().copy_(input.untyped_storage())
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
        else:
            flat = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
            result = flat.view(shape)
        result = cls(result)
    else:
        # File-backed storage.
        handler = None
        path = str(filename)
        if not existsok and os.path.exists(path):
            raise RuntimeError(f"The file {filename} already exists.")
        result = torch.from_file(
            path,
            shared=True,
            dtype=input.dtype,
            size=shape_numel,
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        if nested:
            compute_layout = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if compute_layout is None:
                raise RuntimeError(NESTED_TENSOR_ERR)
            offsets_strides = compute_layout(shape)
            if copy_data:
                result.untyped_storage().copy_(input.untyped_storage())
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
        else:
            result = result.view(shape)
        result = cls(result)

    result._handler = handler
    result.filename = filename
    result.index = None
    result.parent_shape = shape
    if copy_data:
        if hasattr(input, "full_tensor"):
            # for DTensors, cheaper than importing DTensor every time
            input = input.full_tensor()
        # nested results were already filled through their storage above
        if not result.is_nested:
            result.copy_(input)
    return result
|
|
353
|
+
|
|
354
|
+
@classmethod
|
|
355
|
+
def from_storage(
|
|
356
|
+
cls,
|
|
357
|
+
storage,
|
|
358
|
+
*,
|
|
359
|
+
shape: torch.Size | None = None,
|
|
360
|
+
dtype: torch.dtype | None = None,
|
|
361
|
+
device: torch.device | None = None,
|
|
362
|
+
index: IndexType | None = None,
|
|
363
|
+
filename: Path | str = None,
|
|
364
|
+
handler: _handler = None,
|
|
365
|
+
):
|
|
366
|
+
if getattr(storage, "filename", None) is not None:
|
|
367
|
+
if filename is None:
|
|
368
|
+
filename = storage.filename
|
|
369
|
+
elif str(storage.filename) != str(filename):
|
|
370
|
+
raise RuntimeError(
|
|
371
|
+
"Providing a storage with an associated filename that differs from the filename argument is not permitted unless filename=None. "
|
|
372
|
+
f"Got filename={str(filename)}, storage.filename={str(storage.filename)}"
|
|
373
|
+
)
|
|
374
|
+
tensor = torch.tensor(storage, dtype=dtype, device=device)
|
|
375
|
+
if shape is not None:
|
|
376
|
+
if isinstance(shape, torch.Tensor):
|
|
377
|
+
func_offset_stride = getattr(
|
|
378
|
+
torch, "_nested_compute_contiguous_strides_offsets", None
|
|
379
|
+
)
|
|
380
|
+
if func_offset_stride is not None:
|
|
381
|
+
offsets_strides = func_offset_stride(shape)
|
|
382
|
+
else:
|
|
383
|
+
raise RuntimeError(
|
|
384
|
+
"The PyTorch version isn't compatible with memmap "
|
|
385
|
+
"nested tensors. Please upgrade to a more recent "
|
|
386
|
+
"version."
|
|
387
|
+
)
|
|
388
|
+
tensor = torch._nested_view_from_buffer(
|
|
389
|
+
tensor,
|
|
390
|
+
shape,
|
|
391
|
+
*offsets_strides,
|
|
392
|
+
)
|
|
393
|
+
else:
|
|
394
|
+
tensor = tensor.view(shape)
|
|
395
|
+
|
|
396
|
+
tensor = cls(tensor)
|
|
397
|
+
if filename is not None:
|
|
398
|
+
tensor.filename = filename
|
|
399
|
+
elif handler is not None:
|
|
400
|
+
tensor._handler = handler
|
|
401
|
+
if index is not None:
|
|
402
|
+
return tensor[index]
|
|
403
|
+
return tensor
|
|
404
|
+
|
|
405
|
+
@property
|
|
406
|
+
def filename(self):
|
|
407
|
+
"""The filename of the tensor, if it has one.
|
|
408
|
+
|
|
409
|
+
Raises an exception otherwise.
|
|
410
|
+
"""
|
|
411
|
+
filename = self._filename
|
|
412
|
+
if filename is None:
|
|
413
|
+
raise RuntimeError("The MemoryMappedTensor has no file associated.")
|
|
414
|
+
return filename
|
|
415
|
+
|
|
416
|
+
@filename.setter
|
|
417
|
+
def filename(self, value):
|
|
418
|
+
if value is None and self._filename is None:
|
|
419
|
+
return
|
|
420
|
+
value = str(Path(value).absolute())
|
|
421
|
+
if self._filename is not None and value != self._filename:
|
|
422
|
+
raise RuntimeError(
|
|
423
|
+
"the MemoryMappedTensor has already a filename associated."
|
|
424
|
+
)
|
|
425
|
+
self._filename = value
|
|
426
|
+
|
|
427
|
+
@classmethod
def empty_like(cls, input, *, filename=None):  # noqa: D417
    """Create a tensor with the shape and dtype of ``input`` without copying content.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    # A 0-dim tensor expanded to the target shape acts as a template.
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    return cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
|
|
444
|
+
|
|
445
|
+
@classmethod
def full_like(cls, input, fill_value, *, filename=None):  # noqa: D417
    """Create a tensor shaped and typed like ``input``, filled with ``fill_value``.

    Args:
        input (torch.Tensor): the tensor to use as an example.
        fill_value (float or equivalent): content of the tensor.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    mapped = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    # The template content is not copied; the value is written in place.
    return mapped.fill_(fill_value)
|
|
463
|
+
|
|
464
|
+
@classmethod
def zeros_like(cls, input, *, filename=None):  # noqa: D417
    """Create a 0-filled tensor with the shape and dtype of ``input``.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    mapped = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    # Nothing was copied from the template, so zero the storage explicitly.
    return mapped.fill_(0.0)
|
|
481
|
+
|
|
482
|
+
@classmethod
def ones_like(cls, input, *, filename=None):  # noqa: D417
    """Create a 1-filled tensor with the shape and dtype of ``input``.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.ones((), dtype=input.dtype, device=input.device)
    mapped = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    # Nothing was copied from the template, so write the ones explicitly.
    return mapped.fill_(1.0)
|
|
499
|
+
|
|
500
|
+
@classmethod
@overload
def ones(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def ones(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def ones(cls, *args, **kwargs):  # noqa: D417
    """Create a 1-filled tensor with the given shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.

    Raises:
        RuntimeError: if ``device`` is not a CPU device.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    template = torch.ones((), dtype=dtype, device=device)
    if isinstance(shape, torch.Tensor):
        # nested shapes: allocate through ``empty`` and fill in place
        nested = cls.empty(shape, device=device, dtype=dtype, filename=filename)
        return nested.fill_(1)
    if shape:
        # accept both ones(2, 3) and ones((2, 3))
        wrapped = isinstance(shape[0], (list, tuple)) and len(shape) == 1
        shape = torch.Size(shape[0] if wrapped else shape)
        template = template.expand(shape)
    return cls.from_tensor(
        template,
        filename=filename,
        existsok=kwargs.get("existsok", False),
    )
|
|
546
|
+
|
|
547
|
+
@classmethod
@overload
def zeros(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def zeros(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def zeros(cls, *args, **kwargs):  # noqa: D417
    """Create a 0-filled tensor with the given shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.

    Raises:
        RuntimeError: if ``device`` is not a CPU device.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    if isinstance(shape, torch.Tensor):
        # nested shapes: allocate through ``empty`` and fill in place
        nested = cls.empty(shape, device=device, dtype=dtype, filename=filename)
        return nested.fill_(0)
    template = torch.zeros((), dtype=dtype, device=device)
    if shape:
        # accept both zeros(2, 3) and zeros((2, 3))
        wrapped = isinstance(shape[0], (list, tuple)) and len(shape) == 1
        shape = torch.Size(shape[0] if wrapped else shape)
        template = template.expand(shape)
    return cls.from_tensor(
        template,
        filename=filename,
        existsok=kwargs.get("existsok", False),
    )
|
|
594
|
+
|
|
595
|
+
@classmethod
@overload
def empty(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def empty(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def empty(cls, *args, **kwargs):  # noqa: D417
    """Create a tensor with empty content and the given shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor. A
            ``torch.Tensor`` value is treated as the shapes of a nested tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``. Only read on the non-nested path.

    Raises:
        RuntimeError: if ``device`` is not a CPU device, or if torch lacks
            nested-tensor support.
        ValueError: for complex dtypes on the nested, handler-backed path.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    template = torch.zeros((), dtype=dtype, device=device)

    if isinstance(shape, torch.Tensor):
        # nested tensor
        shape_numel = shape.prod(-1).sum()

        if filename is None:
            # Handler-backed buffer: byte size derived from the dtype.
            if dtype.is_floating_point:
                size = torch.finfo(dtype).bits // 8 * shape_numel
            elif dtype.is_complex:
                raise ValueError(
                    "Complex-valued tensors are not supported by MemoryMappedTensor."
                )
            elif dtype == torch.bool:
                size = shape_numel
            else:
                # assume integer
                size = torch.iinfo(dtype).bits // 8 * shape_numel
            handler = _FileHandler(size)

            compute_layout = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if compute_layout is None:
                raise RuntimeError(NESTED_TENSOR_ERR)
            offsets_strides = compute_layout(shape)
            flat = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
            nested = cls(
                torch._nested_view_from_buffer(
                    flat,
                    shape,
                    *offsets_strides,
                )
            )
            nested._handler = handler
            return nested

        # File-backed buffer.
        flat = torch.from_file(
            str(filename),
            shared=True,
            dtype=dtype,
            size=shape_numel,
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        compute_layout = getattr(
            torch, "_nested_compute_contiguous_strides_offsets", None
        )
        if compute_layout is None:
            raise RuntimeError(NESTED_TENSOR_ERR)
        offsets_strides = compute_layout(shape)
        nested = cls(
            torch._nested_view_from_buffer(
                flat,
                shape,
                *offsets_strides,
            )
        )
        nested.filename = filename
        return nested

    if shape:
        # accept both empty(2, 3) and empty((2, 3))
        wrapped = isinstance(shape[0], (list, tuple)) and len(shape) == 1
        shape = torch.Size(shape[0] if wrapped else shape)
        template = template.expand(shape)
    return cls.from_tensor(
        template,
        filename=filename,
        copy_data=False,
        existsok=kwargs.get("existsok", False),
    )
|
|
700
|
+
|
|
701
|
+
@classmethod
def empty_nested(cls, *args, **kwargs):  # noqa: D417
    """Creates a tensor with empty content, specific shape, dtype and filename.

    Args:
        shape (nested_shape): the shapes of the tensors. May be given as the
            first positional argument or as the ``shape`` keyword.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.

    Raises:
        IndexError: if no shape is given, neither positionally nor by keyword.
        RuntimeError: if ``device`` is not a CPU device.
    """
    # ``kwargs.pop("shape", args[0])`` evaluates ``args[0]`` before the lookup,
    # so a keyword-only call raised IndexError even though ``shape`` was given.
    # Only index ``args`` when the keyword is absent.
    if "shape" in kwargs:
        shape = kwargs.pop("shape")
    else:
        shape = args[0]
    # An empty size is prepended, so the shape parsed by the helper is discarded.
    args = (torch.Size([]), *args)
    _, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    result = torch.zeros((), dtype=dtype, device=device)
    if shape:
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    result = cls.from_tensor(
        result,
        filename=filename,
        copy_data=False,
        existsok=kwargs.pop("existsok", False),
    )
    return result
|
|
739
|
+
|
|
740
|
+
@classmethod
@overload
def full(cls, *size, fill_value, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None): ...

@classmethod
def full(cls, *args, **kwargs):  # noqa: D417
    """Create a tensor filled with ``fill_value``, with the given shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        fill_value (float or equivalent): content of the tensor.
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.

    Raises:
        RuntimeError: if ``device`` is not a CPU device.
    """
    shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    # A filled 0-dim tensor, expanded to the target shape, is the copy source.
    template = torch.zeros((), dtype=dtype, device=device).fill_(fill_value)
    if shape:
        # accept both full(2, 3, ...) and full((2, 3), ...)
        wrapped = isinstance(shape[0], (list, tuple)) and len(shape) == 1
        shape = torch.Size(shape[0] if wrapped else shape)
        template = template.expand(shape)
    return cls.from_tensor(
        template,
        filename=filename,
        existsok=kwargs.get("existsok", False),
    )
|
|
781
|
+
|
|
782
|
+
@classmethod
def from_filename(cls, filename, dtype, shape, readonly: bool = False, index=None):
    # noqa: D417
    """Build a MemoryMappedTensor on top of the file at ``filename``.

    Args:
        filename (path or equivalent): the path to the file.
        dtype (torch.dtype): the dtype of the tensor.
        shape (torch.Size or torch.Tensor): the shape of the tensor. If
            a tensor is provided, it is assumed that the tensor is a nested_tensor
            instance.
        readonly (bool, optional): whether to open the file in read-only mode.
            Defaults to ``False``.
        index (torch-compatible index type): an index to use to build the
            tensor.

    Raises:
        RuntimeError: if ``shape`` is a tensor and the installed PyTorch lacks
            ``torch._nested_compute_contiguous_strides_offsets``.
    """
    # The mapping is shared only when the file is writable and the caller did
    # not ask for read-only access.
    share_mapping = _is_writable(filename) and not readonly

    if isinstance(shape, torch.Tensor):
        compute_layout = getattr(
            torch, "_nested_compute_contiguous_strides_offsets", None
        )
        if compute_layout is None:
            raise RuntimeError(
                "The PyTorch version isn't compatible with memmap "
                "nested tensors. Please upgrade to a more recent "
                "version."
            )
        offsets_strides = compute_layout(shape)
        flat = torch.from_file(
            str(filename),
            shared=share_mapping,
            dtype=dtype,
            size=shape.prod(-1).sum().int(),
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        tensor = torch._nested_view_from_buffer(flat, shape, *offsets_strides)
    else:
        shape = torch.Size(shape)
        flat = torch.from_file(
            str(filename),
            shared=share_mapping,
            dtype=dtype,
            size=shape.numel(),
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        tensor = flat.view(shape)

    if index is not None:
        tensor = tensor[index]

    # Record where the data lives so the instance can be rebuilt later.
    result = cls(tensor)
    result.filename = filename
    result._handler = None
    result.index = index
    result.parent_shape = shape
    return result
|
|
847
|
+
|
|
848
|
+
@classmethod
def from_handler(cls, handler, dtype, shape, index=None):
    # noqa: D417
    """Build a MemoryMappedTensor on top of a handler's buffer.

    Args:
        handler (compatible file handler): the handler for the tensor.
        dtype (torch.dtype): the dtype of the tensor.
        shape (torch.Size or torch.Tensor): the shape of the tensor. If
            a tensor is provided, it is assumed that the tensor is a nested_tensor
            instance.
        index (torch-compatible index type, optional): an index to use to build the
            tensor.

    Raises:
        RuntimeError: if ``shape`` is a tensor and the installed PyTorch lacks
            ``torch._nested_compute_contiguous_strides_offsets``.
    """
    flat = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)

    if isinstance(shape, torch.Tensor):
        compute_layout = getattr(
            torch, "_nested_compute_contiguous_strides_offsets", None
        )
        if compute_layout is None:
            raise RuntimeError(
                "The PyTorch version isn't compatible with memmap "
                "nested tensors. Please upgrade to a more recent "
                "version."
            )
        offsets_strides = compute_layout(shape)
        view = torch._nested_view_from_buffer(flat, shape, *offsets_strides)
    else:
        shape = torch.Size(shape)
        view = torch.reshape(flat, shape)

    if index is not None:
        view = view[index]

    # Record the handler (and no filename) so the instance can be rebuilt later.
    result = cls(view)
    result.filename = None
    result._handler = handler
    result.index = index
    result.parent_shape = shape
    return result
|
|
893
|
+
|
|
894
|
+
@property
def _tensor(self):
    """Return this very object (legacy accessor)."""
    # for bc-compatibility with MemmapTensor, to be deprecated in v0.4
    return self
|
|
898
|
+
|
|
899
|
+
def __setstate__(self, state):
    """Restore the instance by rebuilding it from its pickled description.

    A state holding a ``"filename"`` key is rebuilt with ``from_filename``;
    any other state is rebuilt with ``from_handler``. The rebuilt object's
    ``__dict__`` then replaces this instance's ``__dict__``.
    """
    owner = type(self)
    rebuild = owner.from_filename if "filename" in state else owner.from_handler
    self.__dict__ = rebuild(**state).__dict__
|
|
904
|
+
|
|
905
|
+
def __getstate__(self):
    """Describe the instance by its backing store instead of its data.

    Returns:
        dict: keyword arguments for ``from_handler`` when a handler is set,
        otherwise keyword arguments for ``from_filename``.

    Raises:
        RuntimeError: if neither a handler nor a filename is available.
    """
    handler = getattr(self, "_handler", None)
    if handler is not None:
        return {
            "handler": handler,
            "dtype": self.dtype,
            "shape": list(self.parent_shape),
            "index": self.index,
        }

    filename = getattr(self, "_filename", None)
    if filename is not None:
        return {
            "filename": filename,
            "dtype": self.dtype,
            "shape": self.parent_shape,
            "index": self.index,
        }

    raise RuntimeError("Could not find handler or filename.")
|
|
922
|
+
|
|
923
|
+
def __reduce_ex__(self, protocol):
    """Delegate to ``__reduce__``; the ``protocol`` argument is not used."""
    reduced = self.__reduce__()
    return reduced
|
|
925
|
+
|
|
926
|
+
def __reduce__(self):
    """Return a ``(constructor, args)`` pair that rebuilds the instance.

    ``from_handler`` is used when a handler is set, otherwise
    ``from_filename`` when a filename is set.

    Raises:
        RuntimeError: if neither a handler nor a filename is available.
    """
    owner = type(self)

    handler = getattr(self, "_handler", None)
    if handler is not None:
        rebuild_args = (handler, self.dtype, self.parent_shape, self.index)
        return owner.from_handler, rebuild_args

    filename = getattr(self, "_filename", None)
    if filename is not None:
        rebuild_args = (filename, self.dtype, self.parent_shape, self.index)
        return owner.from_filename, rebuild_args

    raise RuntimeError("Could not find handler or filename.")
|
|
943
|
+
|
|
944
|
+
@implement_for("torch", "2.0", None)
def __getitem__(self, item: IndexType) -> Self | torch.Tensor:
    """Index the tensor, keeping the memory-map metadata on shared-storage results.

    This variant compares storages through ``untyped_storage()``.

    Raises:
        ValueError: if indexing fails; an "is unbound" error is re-raised with
            a message about first class dimension indices.
    """
    try:
        result = super().__getitem__(item)
    except ValueError as err:
        if "is unbound" not in str(err):
            raise
        raise ValueError(
            "Using first class dimension indices with MemoryMappedTensor "
            "isn't supported at the moment."
        ) from err

    # Only results that still point at the same storage are re-wrapped.
    same_storage = (
        result.untyped_storage().data_ptr() == self.untyped_storage().data_ptr()
    )
    if same_storage:
        result = self._index_wrap(result, item)
    return result
|
|
958
|
+
|
|
959
|
+
@implement_for("torch", None, "2.0")
def __getitem__(self, item: IndexType) -> Self | torch.Tensor:  # noqa: F811
    """Index the tensor, keeping the memory-map metadata on shared-storage results.

    This variant compares storages through ``storage()``.

    Raises:
        ValueError: if indexing fails; an "is unbound" error is re-raised with
            a message about first class dimension indices.
    """
    try:
        result = super().__getitem__(item)
    except ValueError as err:
        if "is unbound" not in str(err):
            raise
        raise ValueError(
            "Using first class dimension indices with MemoryMappedTensor "
            "isn't supported at the moment."
        ) from err

    # Only results that still point at the same storage are re-wrapped.
    same_storage = result.storage().data_ptr() == self.storage().data_ptr()
    if same_storage:
        result = self._index_wrap(result, item)
    return result
|
|
973
|
+
|
|
974
|
+
def _index_wrap(self, tensor, item, check=False):
    """Wrap ``tensor`` as a MemoryMappedTensor that inherits this one's backing info.

    Args:
        tensor: the tensor produced by indexing or splitting ``self``.
        item: the index recorded on the wrapped tensor.
        check (bool, optional): when true, ``tensor`` is returned untouched
            unless its storage pointer equals that of ``self``.

    Returns:
        The wrapped tensor, or ``tensor`` itself when the check fails.
    """
    if check and (
        tensor.untyped_storage().data_ptr() != self.untyped_storage().data_ptr()
    ):
        return tensor

    wrapped = MemoryMappedTensor(tensor)
    wrapped._handler = getattr(self, "_handler", None)
    wrapped.filename = getattr(self, "_filename", None)
    wrapped.index = item
    wrapped.parent_shape = getattr(self, "parent_shape", None)
    return wrapped
|
|
985
|
+
|
|
986
|
+
def unbind(self, dim=0):
    """Remove dimension ``dim`` and return the resulting slices, re-wrapped.

    Each slice is passed through ``_index_wrap`` with the index
    ``(slice(None),) * dim + (i,)`` so that it keeps a reference to the
    parent's backing store.

    Args:
        dim (int, optional): dimension to remove; negative values count from
            the end. Defaults to ``0``, as in ``torch.Tensor.unbind``, so that
            ``tensor.unbind()`` works on memory-mapped tensors too.

    Returns:
        tuple: one wrapped tensor per entry along ``dim``.
    """
    out = super().unbind(dim)
    if dim < 0:
        # Normalise so the recorded index has a non-negative number of
        # leading full slices.
        dim = self.ndim + dim
    index_base = (slice(None),) * dim
    return tuple(
        self._index_wrap(_out, index_base + (i,)) for i, _out in enumerate(out)
    )
|
|
994
|
+
|
|
995
|
+
def chunk(self, chunks, dim=0):
    """Split the tensor into chunks along ``dim`` and re-wrap each piece.

    Every piece is passed through ``_index_wrap(..., check=True)`` together
    with the slice of ``self`` it corresponds to.

    Args:
        chunks: forwarded to the parent class's ``chunk``.
        dim (int, optional): dimension to split along. Defaults to ``0``.

    Returns:
        tuple: the pieces, in order.
    """
    dim = _maybe_correct_neg_dim(dim, self.shape)
    pieces = super().chunk(chunks, dim)

    leading = (slice(None),) * dim
    indices = []
    start = 0
    for piece in pieces:
        stop = start + piece.shape[dim]
        indices.append(leading + (slice(start, stop),))
        start = stop

    return tuple(
        self._index_wrap(piece, idx, check=True)
        for piece, idx in _zip_strict(pieces, indices)
    )
|
|
1010
|
+
|
|
1011
|
+
def split(self, split_size, dim=0):
    """Split the tensor along ``dim`` and re-wrap each piece.

    Every piece is passed through ``_index_wrap(..., check=True)`` together
    with the slice of ``self`` it corresponds to.

    Args:
        split_size: forwarded to the parent class's ``split``.
        dim (int, optional): dimension to split along. Defaults to ``0``.

    Returns:
        tuple: the pieces, in order.
    """
    dim = _maybe_correct_neg_dim(dim, self.shape)
    pieces = super().split(split_size, dim)

    leading = (slice(None),) * dim
    indices = []
    start = 0
    for piece in pieces:
        stop = start + piece.shape[dim]
        indices.append(leading + (slice(start, stop),))
        start = stop

    return tuple(
        self._index_wrap(piece, idx, check=True)
        for piece, idx in _zip_strict(pieces, indices)
    )
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
#####################
|
|
1029
|
+
# File handler
|
|
1030
|
+
# borrowed from mp.heap
|
|
1031
|
+
|
|
1032
|
+
if sys.platform == "win32":
    import _winapi

    class _FileHandler:
        """Holder of an anonymous, tag-named memory map of a given size."""

        _rand = tempfile._RandomNameSequence()

        def __init__(self, size):
            """Create a fresh tagged mmap of ``size`` bytes.

            Raises:
                FileExistsError: if no unused tag name is found in 100 tries.
            """
            self.size = size
            for _attempt in range(100):
                tag = "pym-%d-%s" % (os.getpid(), next(self._rand))
                mapping = mmap.mmap(-1, size, tagname=tag)
                if _winapi.GetLastError() == 0:
                    break
                # We have reopened a preexisting mmap.
                mapping.close()
            else:
                raise FileExistsError("Cannot find name for new mmap")
            self.name = tag
            self.buffer = mapping
            self._state = (self.size, self.name)

        def __getstate__(self):
            """Return ``(size, name)``; only allowed while spawning a process."""
            from multiprocessing.context import assert_spawning

            assert_spawning(self)
            return self._state

        def __setstate__(self, state):
            """Reopen the mmap described by the ``(size, name)`` state."""
            self.size, self.name = state
            self._state = state
            # Reopen existing mmap
            self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
            # XXX Temporarily preventing buildbot failures while determining
            # XXX the correct long-term fix. See issue 23060
            # assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS

else:

    class _FileHandler:
        """Holder of a file descriptor and the memory map built on top of it."""

        # Directories tried first by ``_choose_dir``.
        _dir_candidates = ["/dev/shm"] if sys.platform == "linux" else []

        def __init__(self, size, fd=-1):
            """Map ``size`` bytes, creating a temporary file when ``fd`` is ``-1``.

            Args:
                size: number of bytes to map.
                fd: an existing file descriptor to map, or ``-1`` to create
                    a new temporary file.
            """
            self.size = size
            self.fd = fd
            if fd == -1:
                self.fd, tmp_path = tempfile.mkstemp(
                    prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
                )
                # The path is removed right away; only the descriptor is kept.
                os.unlink(tmp_path)
                # Register ``os.close(fd)`` as a finalizer tied to this object.
                util.Finalize(self, os.close, (self.fd,))
                os.ftruncate(self.fd, size)
            self.buffer = mmap.mmap(self.fd, self.size)

        def _choose_dir(self, size):
            """Return the first candidate directory with at least ``size`` bytes free.

            Falls back to ``util.get_temp_dir()`` when no candidate qualifies.
            """
            # Choose a non-storage backed directory if possible,
            # to improve performance
            for candidate in self._dir_candidates:
                stats = os.statvfs(candidate)
                if stats.f_bavail * stats.f_frsize >= size:  # enough free space?
                    return candidate
            return util.get_temp_dir()

    def _reduce_handler(handler):
        """Reduce a handler to ``(_rebuild_handler, (size, DupFd(fd)))``.

        Raises:
            ValueError: if the handler's ``fd`` is ``-1``.
        """
        if handler.fd == -1:
            raise ValueError(
                "Handler is unpicklable because "
                "forking was enabled when it was created"
            )
        return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))

    def _rebuild_handler(size, dupfd):
        """Recreate a handler from a size and a duplicated file descriptor."""
        fd = dupfd.detach()
        return _FileHandler(size, fd)

    reduction.register(_FileHandler, _reduce_handler)
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
def _reduce_memmap(memmap_tensor):
    """Reducer for MemoryMappedTensor: defer to the tensor's own ``__reduce__``."""
    reduced = memmap_tensor.__reduce__()
    return reduced


# Make the multiprocessing pickler use the reducer above for this class.
reduction.register(MemoryMappedTensor, _reduce_memmap)
|
|
1116
|
+
|
|
1117
|
+
|
|
1118
|
+
def _proc_args_const(*args, **kwargs):
    """Normalise constructor arguments into ``(shape, device, dtype, fill_value, filename)``.

    The shape comes from the positional arguments when any are given
    (a single tensor is kept as is, a single non-int is converted with
    ``torch.Size``, anything else is treated as the dimensions themselves),
    otherwise from the ``shape`` keyword. Missing keywords yield ``None``.

    Raises:
        TypeError: if no positional argument and no ``shape`` keyword is given.
    """
    if not args:
        # we should have a "shape" keyword arg
        shape = kwargs.pop("shape", None)
        if shape is None:
            raise TypeError("Could not find the shape argument in the arguments.")
        if not isinstance(shape, torch.Tensor):
            shape = torch.Size(shape)
    elif len(args) > 1 or isinstance(args[0], int):
        # the positional arguments are the dimensions
        shape = torch.Size(args)
    elif isinstance(args[0], torch.Tensor):
        shape = args[0]
    else:
        shape = torch.Size(args[0])

    device = kwargs.pop("device", None)
    dtype = kwargs.pop("dtype", None)
    fill_value = kwargs.pop("fill_value", None)
    filename = kwargs.pop("filename", None)
    return (shape, device, dtype, fill_value, filename)
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
# Torch functions
|
|
1144
|
+
|
|
1145
|
+
# Maps a torch function to the override registered for MemoryMappedTensor.
MEMMAP_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}


def implements_for_memmap(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Register a torch function override for MemoryMappedTensor.

    Args:
        torch_function: the torch callable being overridden; used as the key
            in ``MEMMAP_HANDLED_FUNCTIONS``.

    Returns:
        A decorator that stores the decorated function under
        ``torch_function`` and returns it unchanged.
    """

    # ``functools.wraps(torch_function)`` is deliberately not applied here:
    # this inner function is a registration hook, not a wrapper around
    # ``torch_function``, so it must not take on that function's metadata.
    def decorator(func: Callable) -> Callable:
        MEMMAP_HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
@implements_for_memmap(torch.unbind)
def _unbind(tensor, dim=0):
    """Override registered for ``torch.unbind``: call the tensor's own ``unbind``.

    ``dim`` defaults to ``0`` like ``torch.unbind`` itself, so a call that
    omits the dimension can be routed here without raising ``TypeError``.
    """
    return tensor.unbind(dim)
|
|
1162
|
+
|
|
1163
|
+
|
|
1164
|
+
@implements_for_memmap(torch.chunk)
def _chunk(input, chunks, dim=0):
    """Override registered for ``torch.chunk``: call the tensor's own ``chunk``."""
    pieces = input.chunk(chunks, dim=dim)
    return pieces
|
|
1167
|
+
|
|
1168
|
+
|
|
1169
|
+
def _is_writable(file_path):
    """Tell whether ``file_path`` can be written to.

    An existing path is checked with ``os.access(..., os.W_OK)``. A path that
    does not exist yet is reported as writable.
    """
    path = str(file_path)
    if not os.path.exists(path):
        # Assume that the file can be written in the directory
        return True
    return os.access(path, os.W_OK)
|