unienv 0.0.1b4__py3-none-any.whl → 0.0.1b6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/METADATA +3 -2
  2. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/RECORD +43 -32
  3. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/WHEEL +1 -1
  4. unienv_data/base/common.py +25 -10
  5. unienv_data/base/storage.py +2 -0
  6. unienv_data/batches/backend_compat.py +1 -1
  7. unienv_data/batches/combined_batch.py +1 -1
  8. unienv_data/batches/slicestack_batch.py +1 -0
  9. unienv_data/replay_buffer/replay_buffer.py +179 -65
  10. unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
  11. unienv_data/storages/_episode_storage.py +438 -0
  12. unienv_data/storages/_list_storage.py +136 -0
  13. unienv_data/storages/backend_compat.py +268 -0
  14. unienv_data/storages/dict_storage.py +39 -7
  15. unienv_data/storages/flattened.py +11 -4
  16. unienv_data/storages/hdf5.py +11 -0
  17. unienv_data/storages/image_storage.py +144 -0
  18. unienv_data/storages/npz_storage.py +135 -0
  19. unienv_data/storages/pytorch.py +17 -10
  20. unienv_data/storages/transformation.py +16 -1
  21. unienv_data/storages/video_storage.py +297 -0
  22. unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
  23. unienv_data/transformations/image_compress.py +97 -21
  24. unienv_interface/func_wrapper/frame_stack.py +1 -1
  25. unienv_interface/space/space_utils/batch_utils.py +5 -1
  26. unienv_interface/space/space_utils/flatten_utils.py +8 -2
  27. unienv_interface/space/spaces/dict.py +6 -0
  28. unienv_interface/space/spaces/tuple.py +4 -4
  29. unienv_interface/transformations/__init__.py +3 -1
  30. unienv_interface/transformations/batch_and_unbatch.py +42 -4
  31. unienv_interface/transformations/chained_transform.py +9 -8
  32. unienv_interface/transformations/crop.py +69 -0
  33. unienv_interface/transformations/dict_transform.py +8 -2
  34. unienv_interface/transformations/identity.py +16 -0
  35. unienv_interface/transformations/image_resize.py +106 -0
  36. unienv_interface/transformations/iter_transform.py +92 -0
  37. unienv_interface/transformations/rescale.py +24 -5
  38. unienv_interface/utils/symbol_util.py +7 -1
  39. unienv_interface/wrapper/backend_compat.py +1 -1
  40. unienv_interface/wrapper/frame_stack.py +1 -1
  41. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/licenses/LICENSE +0 -0
  42. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/top_level.txt +0 -0
  43. /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
@@ -0,0 +1,1174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # From https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py
6
+
7
+ from __future__ import annotations
8
+
9
+ import functools
10
+
11
+ import mmap
12
+ import os
13
+
14
+ import sys
15
+ import tempfile
16
+ from multiprocessing import reduction, util
17
+ from pathlib import Path
18
+ from typing import (
19
+ Any,
20
+ Callable,
21
+ Iterator,
22
+ List,
23
+ Sequence,
24
+ Tuple,
25
+ TYPE_CHECKING,
26
+ TypeVar,
27
+ Union,
28
+ overload
29
+ )
30
+ import numpy as np
31
+ from pyvers import implement_for
32
+
33
+ import torch
34
+ from torch import Tensor
35
+ from torch.nn.parameter import (
36
+ UninitializedBuffer,
37
+ UninitializedParameter,
38
+ UninitializedTensorMixin,
39
+ )
40
+
41
# Error raised whenever the running torch build lacks the private
# nested-tensor helpers this module relies on.
NESTED_TENSOR_ERR = (
    "The PyTorch version isn't compatible with "
    "nested tensors. Please upgrade to a more recent "
    "version."
)

# `typing.Self` is only importable on recent Pythons; at runtime fall back
# to `Any` so annotations still resolve under `from __future__ import annotations`.
if TYPE_CHECKING:
    from typing import Self
else:
    Self = Any
51
+
52
+ def _maybe_correct_neg_dim(
53
+ dim: int, shape: torch.Size | None, ndim: int | None = None
54
+ ) -> int:
55
+ """Corrects neg dim to pos."""
56
+ if ndim is None:
57
+ ndim = len(shape)
58
+ if dim < 0:
59
+ new_dim = ndim + dim
60
+ else:
61
+ new_dim = dim
62
+ if new_dim < 0 or new_dim >= ndim:
63
+ if shape is not None:
64
+ raise IndexError(
65
+ f"Incompatible dim {new_dim} for tensordict with shape {shape}."
66
+ )
67
+ raise IndexError(
68
+ f"Incompatible dim {new_dim} for tensordict with batch dims {ndim}."
69
+ )
70
+ return new_dim
71
+
72
+ def _shape(tensor: Tensor, nested_shape=False) -> torch.Size:
73
+ if isinstance(tensor, UninitializedTensorMixin):
74
+ return torch.Size([*getattr(tensor, "batch_size", ()), -1])
75
+ elif not isinstance(tensor, Tensor):
76
+ return tensor.shape
77
+ if tensor.is_nested:
78
+ if nested_shape:
79
+ return tensor._nested_tensor_size()
80
+ shape = []
81
+ for i in range(tensor.ndim):
82
+ try:
83
+ shape.append(tensor.size(i))
84
+ except RuntimeError:
85
+ shape.append(-1)
86
+ return torch.Size(shape)
87
+ return tensor.shape
88
+
89
+ if sys.version_info >= (3, 10):
90
+ _zip_strict = functools.partial(zip, strict=True)
91
+ else:
92
+ def _zip_strict(*iterables):
93
+ iterables = tuple(tuple(it) for it in iterables)
94
+ lengths = {len(it) for it in iterables}
95
+ if len(lengths) > 1:
96
+ raise ValueError("lengths of iterables differ.")
97
+
98
+ return zip(*iterables)
99
+
100
+ IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
101
+
102
+ class MemoryMappedTensor(torch.Tensor):
103
+ """A Memory-mapped Tensor.
104
+
105
+ Supports filenames or file handlers.
106
+
107
+ The main advantage of MemoryMappedTensor resides in its serialization methods,
108
+ which ensure that the tensor is passed through queues or RPC remote calls without
109
+ any copy.
110
+
111
+ .. note::
112
+ When used within RPC settings, the filepath should be accessible to both nodes.
113
+ If it isn't the behaviour of passing a MemoryMappedTensor from one worker
114
+ to another is undefined.
115
+
116
+ MemoryMappedTensor supports multiple construction methods.
117
+
118
+ Examples:
119
+ >>> # from an existing tensor
120
+ >>> tensor = torch.randn(3)
121
+ >>> with tempfile.NamedTemporaryFile() as file:
122
+ ... memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
123
+ ... assert memmap_tensor.filename is not None
124
+ >>> # if no filename is passed, a handler is used
125
+ >>> tensor = torch.randn(3)
126
+ >>> memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=file.name)
127
+ >>> assert memmap_tensor.filename is None
128
+ >>> # one can create an empty tensor too
129
+ >>> with tempfile.NamedTemporaryFile() as file:
130
+ ... memmap_tensor_empty = MemoryMappedTensor.empty_like(tensor, filename=file.name)
131
+ >>> with tempfile.NamedTemporaryFile() as file:
132
+ ... memmap_tensor_zero = MemoryMappedTensor.zeros_like(tensor, filename=file.name)
133
+ >>> with tempfile.NamedTemporaryFile() as file:
134
+ ... memmap_tensor = MemoryMappedTensor.ones_like(tensor, filename=file.name)
135
+ """
136
+
137
+ _filename: str | Path = None
138
+ _handler: _FileHandler = None
139
+ _clear: bool
140
+ index: Any
141
+ parent_shape: torch.Size
142
+
143
def __new__(
    cls,
    source,
    *,
    dtype=None,
    shape=None,
    index=None,
    device=None,
    handler=None,
    filename=None,
):
    """Dispatch construction to the appropriate factory classmethod.

    ``source`` may be a tensor-like, a filename string, or a torch storage;
    alternatively a file ``handler`` may be supplied. Only CPU tensors are
    supported.
    """
    if device is not None and torch.device(device).type != "cpu":
        raise ValueError(f"{cls} device must be cpu!")
    # A string source is shorthand for passing `filename=` alone.
    if isinstance(source, str):
        if filename is not None:
            raise TypeError("Duplicated filename argument.")
        filename = source
        source = None
    if filename is not None:
        if dtype is not None:
            raise TypeError("Cannot pass new dtype if source is provided.")
        result = cls.from_tensor(
            torch.as_tensor(source),
            filename=filename,
            shape=shape,
        )
        return result if index is None else result[index]
    if isinstance(source, torch.StorageBase):
        return cls.from_storage(
            source,
            dtype=dtype,
            shape=shape,
            index=index,
            device=device,
            handler=handler,
            filename=filename,
        )
    if handler is not None:
        return cls.from_handler(
            handler,
            dtype,
            shape,
            index,
        )
    # Plain tensor source: defer to the regular torch.Tensor construction.
    return super().__new__(cls, source)
192
+
193
def __init__(
    self,
    source,
    *,
    handler=None,
    dtype=None,
    shape=None,
    device=None,
    filename=None,
):
    # All real construction happens in __new__; this only mirrors the
    # signature so keyword arguments are accepted and ignored here.
    ...
203
+
204
+ __torch_function__ = torch._C._disabled_torch_function_impl
205
+
206
@classmethod
def from_tensor(
    cls,
    input,
    *,
    filename: Path | str = None,
    existsok: bool = False,
    copy_existing: bool = False,
    copy_data: bool = True,
    shape: torch.Size | None = None,
):  # noqa: D417
    """Creates a MemoryMappedTensor with the same content as another tensor.

    If the tensor is already a MemoryMappedTensor the original tensor is
    returned if the `filename` argument is `None` or if the two paths match.
    In all other cases, a new :class:`MemoryMappedTensor` is produced.

    Args:
        input (torch.Tensor): the tensor which content must be copied onto
            the MemoryMappedTensor.

    Keyword Args:
        filename (path to a file): the path to the file where the tensor
            should be stored. If none is provided, a file handler is used
            instead.
        existsok (bool, optional): if ``True``, the file will overwrite
            an existing file. Defaults to ``False``.
        copy_existing (bool, optional): if ``True`` and the provided input
            is a MemoryMappedTensor with an associated filename, copying
            the content to the new location is permitted. Otherwise, an
            exception is thrown. This behaviour exists to prevent
            inadvertently duplicating data on disk.
        copy_data (bool, optional): if ``True``, the content of the tensor
            will be copied on the storage. Defaults to ``True``.
        shape (torch.Size or torch.Tensor): a shape to override the tensor
            shape. If a tensor is passed, it must represent the nested shapes of a
            nested tensor.

    Raises:
        RuntimeError: if the target file exists (and ``existsok`` is False),
            if ``input`` requires grad, or if an implicit on-disk copy would
            be needed while ``copy_existing`` is False.
        TypeError: if ``input`` is a numpy array.
        ValueError: if ``input`` is complex-valued.
    """
    if isinstance(input, MemoryMappedTensor):
        if (filename is None and input._filename is None) or (
            input._filename is not None
            and filename is not None
            and Path(filename).absolute() == Path(input.filename).absolute()
        ):
            # either location was not specified, or memmap is already in the
            # correct location, so just return the MemmapTensor unmodified
            return input
        elif not copy_existing and (
            input._filename is not None
            and filename is not None
            and Path(filename).absolute() != Path(input.filename).absolute()
        ):
            raise RuntimeError(
                f"A filename was provided but the tensor already has a file associated "
                f"({input.filename}). "
                f"To copy the tensor onto the new location, pass copy_existing=True."
            )
    elif isinstance(input, np.ndarray):
        raise TypeError(
            "Convert input to torch.Tensor before calling MemoryMappedTensor.from_tensor."
        )
    if input.requires_grad:
        raise RuntimeError(
            "MemoryMappedTensor.from_tensor is incompatible with tensor.requires_grad."
        )
    if shape is None:
        shape = _shape(input, nested_shape=True)
    # Number of elements to back on disk; a tensor-valued shape denotes a
    # nested tensor whose per-element shapes are stacked row-wise.
    if isinstance(shape, torch.Tensor):
        shape_numel = shape.prod(-1).sum()
    elif isinstance(shape, torch.Size):
        shape_numel = shape.numel()
    else:
        shape_numel = torch.Size(shape).numel()
    if filename is None:
        # Anonymous mapping: size the buffer in bytes from the dtype width.
        if input.dtype.is_floating_point:
            size = torch.finfo(input.dtype).bits // 8 * shape_numel
        elif input.dtype.is_complex:
            raise ValueError(
                "Complex-valued tensors are not supported by MemoryMappedTensor."
            )
        elif input.dtype == torch.bool:
            size = shape_numel
        else:
            # assume integer
            size = torch.iinfo(input.dtype).bits // 8 * shape_numel
        handler = _FileHandler(size)
        if isinstance(shape, torch.Tensor):
            func_offset_stride = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if func_offset_stride is not None:
                offsets_strides = func_offset_stride(shape)
            else:
                raise RuntimeError(NESTED_TENSOR_ERR)
            result = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
            if copy_data:
                result.untyped_storage().copy_(input.untyped_storage())
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
        else:
            result = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
            result = result.view(shape)
        result = cls(result)
    else:
        handler = None
        if not existsok and os.path.exists(str(filename)):
            # BUGFIX: the message previously never interpolated the path
            # (f-string without a placeholder).
            raise RuntimeError(f"The file {filename} already exists.")
        result = torch.from_file(
            str(filename),
            shared=True,
            dtype=input.dtype,
            size=shape_numel,
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        if isinstance(shape, torch.Tensor):
            func_offset_stride = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if func_offset_stride is not None:
                offsets_strides = func_offset_stride(shape)
            else:
                raise RuntimeError(NESTED_TENSOR_ERR)
            if copy_data:
                result.untyped_storage().copy_(input.untyped_storage())
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
        else:
            result = result.view(shape)
        result = cls(result)
    result._handler = handler
    result.filename = filename
    result.index = None
    result.parent_shape = shape
    if copy_data:
        if hasattr(input, "full_tensor"):
            # for DTensors, cheaper than importing DTensor every time
            input = input.full_tensor()
        if not result.is_nested:
            result.copy_(input)
    return result
353
+
354
@classmethod
def from_storage(
    cls,
    storage,
    *,
    shape: torch.Size | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    index: IndexType | None = None,
    filename: Path | str = None,
    # BUGFIX: annotation previously named the `_handler` attribute
    # instead of the `_FileHandler` type.
    handler: _FileHandler = None,
):
    """Builds a MemoryMappedTensor around an existing torch storage.

    Keyword Args:
        shape (torch.Size or torch.Tensor): target shape; a tensor denotes
            the stacked shapes of a nested tensor.
        dtype (torch.dtype, optional): the dtype of the tensor.
        device (torch.device, optional): must be CPU (or None).
        index (torch-compatible index, optional): an index applied to the
            resulting tensor before returning.
        filename (path, optional): the backing file; must match the
            storage's own filename when both are set.
        handler (_FileHandler, optional): the in-memory file handler, used
            only when no filename is associated.

    Raises:
        RuntimeError: if ``filename`` disagrees with ``storage.filename``.
    """
    if getattr(storage, "filename", None) is not None:
        if filename is None:
            filename = storage.filename
        elif str(storage.filename) != str(filename):
            raise RuntimeError(
                "Providing a storage with an associated filename that differs from the filename argument is not permitted unless filename=None. "
                f"Got filename={str(filename)}, storage.filename={str(storage.filename)}"
            )
    tensor = torch.tensor(storage, dtype=dtype, device=device)
    if shape is not None:
        if isinstance(shape, torch.Tensor):
            # Nested tensor: rebuild the ragged view over the flat buffer.
            func_offset_stride = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if func_offset_stride is not None:
                offsets_strides = func_offset_stride(shape)
            else:
                raise RuntimeError(
                    "The PyTorch version isn't compatible with memmap "
                    "nested tensors. Please upgrade to a more recent "
                    "version."
                )
            tensor = torch._nested_view_from_buffer(
                tensor,
                shape,
                *offsets_strides,
            )
        else:
            tensor = tensor.view(shape)

    tensor = cls(tensor)
    if filename is not None:
        tensor.filename = filename
    elif handler is not None:
        tensor._handler = handler
    if index is not None:
        return tensor[index]
    return tensor
404
+
405
@property
def filename(self):
    """The filename of the tensor, if it has one.

    Raises an exception otherwise.
    """
    path = self._filename
    if path is None:
        raise RuntimeError("The MemoryMappedTensor has no file associated.")
    return path

@filename.setter
def filename(self, value):
    # Setting None on a file-less tensor is a no-op; anything else must
    # agree with the path already recorded (filenames are write-once).
    if value is None and self._filename is None:
        return
    value = str(Path(value).absolute())
    if self._filename is not None and value != self._filename:
        raise RuntimeError(
            "the MemoryMappedTensor has already a filename associated."
        )
    self._filename = value
426
+
427
@classmethod
def empty_like(cls, input, *, filename=None):
    # noqa: D417
    """Creates a tensor with no content but the same shape and dtype as the input tensor.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    # A 0-dim zero expanded to input's shape allocates nothing; the memmap
    # storage is left uninitialized because copy_data=False.
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    return cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
444
+
445
@classmethod
def full_like(cls, input, fill_value, *, filename=None):
    # noqa: D417
    """Creates a tensor with a single content indicated by the `fill_value` argument, but the same shape and dtype as the input tensor.

    Args:
        input (torch.Tensor): the tensor to use as an example.
        fill_value (float or equivalent): content of the tensor.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    out = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    # Fill in place after mapping so the value lands directly on disk.
    return out.fill_(fill_value)
463
+
464
@classmethod
def zeros_like(cls, input, *, filename=None):
    # noqa: D417
    """Creates a tensor with a 0-filled content, but the same shape and dtype as the input tensor.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.zeros((), dtype=input.dtype, device=input.device)
    out = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    # Explicit fill: the mapped storage is not guaranteed to be zeroed.
    return out.fill_(0.0)
481
+
482
@classmethod
def ones_like(cls, input, *, filename=None):
    # noqa: D417
    """Creates a tensor with a 1-filled content, but the same shape and dtype as the input tensor.

    Args:
        input (torch.Tensor): the tensor to use as an example.

    Keyword Args:
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
    """
    template = torch.ones((), dtype=input.dtype, device=input.device)
    out = cls.from_tensor(
        template.expand_as(input),
        filename=filename,
        copy_data=False,
    )
    return out.fill_(1.0)
499
+
500
@classmethod
@overload
def ones(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def ones(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def ones(cls, *args, **kwargs):
    # noqa: D417
    """Creates a tensor with a 1-filled content, specific shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    result = torch.ones((), dtype=dtype, device=device)
    if isinstance(shape, torch.Tensor):
        # Nested shape: allocate empty then fill, `expand` can't do ragged.
        return cls.empty(
            shape, device=device, dtype=dtype, filename=filename
        ).fill_(1)
    if shape:
        # Accept both `ones((2, 3))` and `ones(2, 3)` call styles.
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    return cls.from_tensor(
        result,
        filename=filename,
        existsok=kwargs.pop("existsok", False),
    )
546
+
547
@classmethod
@overload
def zeros(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def zeros(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def zeros(cls, *args, **kwargs):
    # noqa: D417
    """Creates a tensor with a 0-filled content, specific shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    if isinstance(shape, torch.Tensor):
        # Nested shape: allocate empty then fill, `expand` can't do ragged.
        return cls.empty(
            shape, device=device, dtype=dtype, filename=filename
        ).fill_(0)
    result = torch.zeros((), dtype=dtype, device=device)
    if shape:
        # Accept both `zeros((2, 3))` and `zeros(2, 3)` call styles.
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    result = cls.from_tensor(
        result,
        filename=filename,
        existsok=kwargs.pop("existsok", False),
    )
    return result
594
+
595
@classmethod
@overload
def empty(cls, *size, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def empty(cls, shape, *, dtype=None, device=None, filename=None): ...

@classmethod
def empty(cls, *args, **kwargs):
    # noqa: D417
    """Creates a tensor with empty content, specific shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.
    """
    shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    result = torch.zeros((), dtype=dtype, device=device)
    if isinstance(shape, torch.Tensor):
        # nested tensor: total element count across all ragged entries
        shape_numel = shape.prod(-1).sum()

        if filename is None:
            # Anonymous mapping: compute the byte size from the dtype width.
            if dtype.is_floating_point:
                size = torch.finfo(dtype).bits // 8 * shape_numel
            elif dtype.is_complex:
                raise ValueError(
                    "Complex-valued tensors are not supported by MemoryMappedTensor."
                )
            elif dtype == torch.bool:
                size = shape_numel
            else:
                # assume integer
                size = torch.iinfo(dtype).bits // 8 * shape_numel
            handler = _FileHandler(size)

            # buffer
            func_offset_stride = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if func_offset_stride is not None:
                offsets_strides = func_offset_stride(shape)
            else:
                raise RuntimeError(NESTED_TENSOR_ERR)
            result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
            result = cls(result)
            result._handler = handler
            return result
        else:
            result = torch.from_file(
                str(filename),
                shared=True,
                dtype=dtype,
                size=shape_numel,
                # needed when device ctx differs
                device=torch.device("cpu"),
            )
            func_offset_stride = getattr(
                torch, "_nested_compute_contiguous_strides_offsets", None
            )
            if func_offset_stride is not None:
                offsets_strides = func_offset_stride(shape)
            else:
                raise RuntimeError(NESTED_TENSOR_ERR)
            result = torch._nested_view_from_buffer(
                result,
                shape,
                *offsets_strides,
            )
            result = cls(result)
            result.filename = filename
            return result
        return result  # unreachable; kept from the original structure

    if shape:
        # Accept both `empty((2, 3))` and `empty(2, 3)` call styles.
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    result = cls.from_tensor(
        result,
        filename=filename,
        copy_data=False,
        existsok=kwargs.pop("existsok", False),
    )
    return result
700
+
701
@classmethod
def empty_nested(cls, *args, **kwargs):
    # noqa: D417
    """Creates a tensor with empty content, specific shape, dtype and filename.

    Args:
        shape (nested_shape): the shapes of the tensors.

    Keyword Args:
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.
    """
    # Pull the nested shape out, then re-run argument processing with a
    # placeholder empty shape in front so the remaining args line up.
    shape = kwargs.pop("shape", args[0])
    args = (torch.Size([]), *args)
    _, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    result = torch.zeros((), dtype=dtype, device=device)
    if shape:
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    result = cls.from_tensor(
        result,
        filename=filename,
        copy_data=False,
        existsok=kwargs.pop("existsok", False),
    )
    return result
739
+
740
@classmethod
@overload
def full(cls, *size, fill_value, dtype=None, device=None, filename=None): ...

@classmethod
@overload
def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None): ...

@classmethod
def full(cls, *args, **kwargs):
    # noqa: D417
    """Creates a tensor with a single content specified by `fill_value`, specific shape, dtype and filename.

    Args:
        shape (integers or torch.Size): the shape of the tensor.

    Keyword Args:
        fill_value (float or equivalent): content of the tensor.
        dtype (torch.dtype): the dtype of the tensor.
        device (torch.device): the device of the tensor. Only `None` and `"cpu"`
            are accepted, any other device will raise an exception.
        filename (path or equivalent): the path to the file, if any. If none
            is provided, a handler is used.
        existsok (bool, optional): whether it is ok to overwrite an existing file.
            Defaults to ``False``.
    """
    shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
    if device is not None:
        device = torch.device(device)
        if device.type != "cpu":
            raise RuntimeError("Only CPU tensors are supported.")
    # Build a 0-dim prototype holding the fill value, then expand it.
    result = torch.zeros((), dtype=dtype, device=device).fill_(fill_value)
    if shape:
        if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
            shape = torch.Size(shape[0])
        else:
            shape = torch.Size(shape)
        result = result.expand(shape)
    return cls.from_tensor(
        result, filename=filename, existsok=kwargs.pop("existsok", False)
    )
781
+
782
@classmethod
def from_filename(cls, filename, dtype, shape, readonly: bool = False, index=None):
    # noqa: D417
    """Loads a MemoryMappedTensor from a given filename.

    Args:
        filename (path or equivalent): the path to the file.
        dtype (torch.dtype): the dtype of the tensor.
        shape (torch.Size or torch.Tensor): the shape of the tensor. If
            a tensor is provided, it is assumed that the tensor is a nested_tensor
            instance.
        readonly (bool, optional): whether to open the file in read-only mode.
            Defaults to ``False``.
        index (torch-compatible index type): an index to use to build the
            tensor.

    """
    # Only map shared (writable) when the file itself is writable and the
    # caller didn't request read-only access.
    writable = _is_writable(filename)
    shared = writable and not readonly

    if isinstance(shape, torch.Tensor):
        func_offset_stride = getattr(
            torch, "_nested_compute_contiguous_strides_offsets", None
        )
        if func_offset_stride is not None:
            offsets_strides = func_offset_stride(shape)
        else:
            raise RuntimeError(
                "The PyTorch version isn't compatible with memmap "
                "nested tensors. Please upgrade to a more recent "
                "version."
            )
        tensor = torch.from_file(
            str(filename),
            shared=shared,
            dtype=dtype,
            size=shape.prod(-1).sum().int(),
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        tensor = torch._nested_view_from_buffer(
            tensor,
            shape,
            *offsets_strides,
        )
    else:
        shape = torch.Size(shape)
        tensor = torch.from_file(
            str(filename),
            shared=shared,
            dtype=dtype,
            size=shape.numel(),
            # needed when device ctx differs
            device=torch.device("cpu"),
        )
        tensor = tensor.view(shape)

    if index is not None:
        tensor = tensor[index]
    out = cls(tensor)
    out.filename = filename
    out._handler = None
    out.index = index
    out.parent_shape = shape
    return out
847
+
848
@classmethod
def from_handler(cls, handler, dtype, shape, index=None):
    # noqa: D417
    """Loads a MemoryMappedTensor from a given handler.

    Args:
        handler (compatible file handler): the handler for the tensor.
        dtype (torch.dtype): the dtype of the tensor.
        shape (torch.Size or torch.Tensor): the shape of the tensor. If
            a tensor is provided, it is assumed that the tensor is a nested_tensor
            instance.
        index (torch-compatible index type, optional): an index to use to build the
            tensor.

    """
    out = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
    if isinstance(shape, torch.Tensor):
        # Tensor-valued shape: reconstruct a nested view over the flat buffer.
        func_offset_stride = getattr(
            torch, "_nested_compute_contiguous_strides_offsets", None
        )
        if func_offset_stride is not None:
            offsets_strides = func_offset_stride(shape)
        else:
            raise RuntimeError(
                "The PyTorch version isn't compatible with memmap "
                "nested tensors. Please upgrade to a more recent "
                "version."
            )
        out = torch._nested_view_from_buffer(
            out,
            shape,
            *offsets_strides,
        )
    else:
        shape = torch.Size(shape)
        out = torch.reshape(out, shape)

    if index is not None:
        out = out[index]
    out = cls(out)
    out.filename = None
    out._handler = handler
    out.index = index
    out.parent_shape = shape
    return out
893
+
894
    @property
    def _tensor(self):
        """Return ``self`` (MemoryMappedTensor is itself the tensor)."""
        # for bc-compatibility with MemmapTensor, to be deprecated in v0.4
        return self
898
+
899
+ def __setstate__(self, state):
900
+ if "filename" in state:
901
+ self.__dict__ = type(self).from_filename(**state).__dict__
902
+ else:
903
+ self.__dict__ = type(self).from_handler(**state).__dict__
904
+
905
+ def __getstate__(self):
906
+ if getattr(self, "_handler", None) is not None:
907
+ return {
908
+ "handler": self._handler,
909
+ "dtype": self.dtype,
910
+ "shape": list(self.parent_shape),
911
+ "index": self.index,
912
+ }
913
+ elif getattr(self, "_filename", None) is not None:
914
+ return {
915
+ "filename": self._filename,
916
+ "dtype": self.dtype,
917
+ "shape": self.parent_shape,
918
+ "index": self.index,
919
+ }
920
+ else:
921
+ raise RuntimeError("Could not find handler or filename.")
922
+
923
    def __reduce_ex__(self, protocol):
        # The protocol argument is deliberately ignored: pickling always
        # funnels through __reduce__ so the handler/filename-based rebuild
        # path is used.
        return self.__reduce__()
925
+
926
+ def __reduce__(self):
927
+ if getattr(self, "_handler", None) is not None:
928
+ return type(self).from_handler, (
929
+ self._handler,
930
+ self.dtype,
931
+ self.parent_shape,
932
+ self.index,
933
+ )
934
+ elif getattr(self, "_filename", None) is not None:
935
+ return type(self).from_filename, (
936
+ self._filename,
937
+ self.dtype,
938
+ self.parent_shape,
939
+ self.index,
940
+ )
941
+ else:
942
+ raise RuntimeError("Could not find handler or filename.")
943
+
944
    @implement_for("torch", "2.0", None)
    def __getitem__(self, item: IndexType) -> Self | torch.Tensor:
        # torch>=2.0 variant: view detection uses untyped_storage().
        try:
            out = super().__getitem__(item)
        except ValueError as err:
            # First-class-dimension indices surface as a ValueError whose
            # message contains "is unbound"; re-raise with a clearer message.
            if "is unbound" in str(err):
                raise ValueError(
                    "Using first class dimension indices with MemoryMappedTensor "
                    "isn't supported at the moment."
                ) from err
            raise
        # If indexing returned a view over the same storage, re-wrap it so the
        # result stays a MemoryMappedTensor carrying the parent's metadata.
        if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
            out = self._index_wrap(out, item)
        return out
958
+
959
    @implement_for("torch", None, "2.0")
    def __getitem__(self, item: IndexType) -> Self | torch.Tensor:  # noqa: F811
        # torch<2.0 variant: identical logic, but uses the legacy storage()
        # accessor for the same-storage (view) check.
        try:
            out = super().__getitem__(item)
        except ValueError as err:
            # First-class-dimension indices surface as a ValueError whose
            # message contains "is unbound"; re-raise with a clearer message.
            if "is unbound" in str(err):
                raise ValueError(
                    "Using first class dimension indices with MemoryMappedTensor "
                    "isn't supported at the moment."
                ) from err
            raise
        # Re-wrap views over the same storage to keep them memory-mapped.
        if out.storage().data_ptr() == self.storage().data_ptr():
            out = self._index_wrap(out, item)
        return out
973
+
974
+ def _index_wrap(self, tensor, item, check=False):
975
+ if check:
976
+ if tensor.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
977
+ return self._index_wrap(tensor, item)
978
+ return tensor
979
+ tensor = MemoryMappedTensor(tensor)
980
+ tensor._handler = getattr(self, "_handler", None)
981
+ tensor.filename = getattr(self, "_filename", None)
982
+ tensor.index = item
983
+ tensor.parent_shape = getattr(self, "parent_shape", None)
984
+ return tensor
985
+
986
+ def unbind(self, dim):
987
+ out = super().unbind(dim)
988
+ if dim < 0:
989
+ dim = self.ndim + dim
990
+ index_base = (slice(None),) * dim
991
+ return tuple(
992
+ self._index_wrap(_out, index_base + (i,)) for i, _out in enumerate(out)
993
+ )
994
+
995
+ def chunk(self, chunks, dim=0):
996
+ dim = _maybe_correct_neg_dim(dim, self.shape)
997
+ out = super().chunk(chunks, dim)
998
+ slices = []
999
+ i = 0
1000
+ for chunk in out:
1001
+ slices.append(
1002
+ tuple(slice(None) for _ in range(dim))
1003
+ + (slice(i, i + chunk.shape[dim]),)
1004
+ )
1005
+ i += chunk.shape[dim]
1006
+ return tuple(
1007
+ self._index_wrap(chunk, _slice, check=True)
1008
+ for chunk, _slice in _zip_strict(out, slices)
1009
+ )
1010
+
1011
+ def split(self, split_size, dim=0):
1012
+ dim = _maybe_correct_neg_dim(dim, self.shape)
1013
+ out = super().split(split_size, dim)
1014
+ slices = []
1015
+ i = 0
1016
+ for chunk in out:
1017
+ slices.append(
1018
+ tuple(slice(None) for _ in range(dim))
1019
+ + (slice(i, i + chunk.shape[dim]),)
1020
+ )
1021
+ i += chunk.shape[dim]
1022
+ return tuple(
1023
+ self._index_wrap(split, _slice, check=True)
1024
+ for split, _slice in _zip_strict(out, slices)
1025
+ )
1026
+
1027
+
1028
+ #####################
1029
+ # File handler
1030
+ # borrowed from mp.heap
1031
+
1032
if sys.platform == "win32":
    import _winapi

    class _FileHandler:
        # Windows: anonymous shared memory addressed by a tag name
        # (borrowed from multiprocessing.heap).
        _rand = tempfile._RandomNameSequence()

        def __init__(self, size):
            self.size = size
            # Retry until a tag name is found that does not collide with a
            # pre-existing mapping; GetLastError() == 0 means freshly created.
            for _ in range(100):
                name = "pym-%d-%s" % (os.getpid(), next(self._rand))
                buf = mmap.mmap(-1, size, tagname=name)
                if _winapi.GetLastError() == 0:
                    break
                # We have reopened a preexisting mmap.
                buf.close()
            else:
                raise FileExistsError("Cannot find name for new mmap")
            self.name = name
            self.buffer = buf
            self._state = (self.size, self.name)

        def __getstate__(self):
            from multiprocessing.context import assert_spawning

            # Only picklable while spawning a child process.
            assert_spawning(self)
            return self._state

        def __setstate__(self, state):
            self.size, self.name = self._state = state
            # Reopen existing mmap
            self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
            # XXX Temporarily preventing buildbot failures while determining
            # XXX the correct long-term fix. See issue 23060
            # assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS

else:

    class _FileHandler:
        # POSIX: shared memory backed by an unlinked temporary file.
        if sys.platform == "linux":
            _dir_candidates = ["/dev/shm"]
        else:
            _dir_candidates = []

        def __init__(self, size, fd=-1):
            self.size = size
            self.fd = fd
            if fd == -1:
                self.fd, name = tempfile.mkstemp(
                    prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
                )
                # Unlink right away: the open fd keeps the storage alive and
                # the file vanishes automatically when the fd is closed.
                os.unlink(name)
                util.Finalize(self, os.close, (self.fd,))
                os.ftruncate(self.fd, size)
            self.buffer = mmap.mmap(self.fd, self.size)

        def _choose_dir(self, size):
            # Choose a non-storage backed directory if possible,
            # to improve performance
            for d in self._dir_candidates:
                st = os.statvfs(d)
                if st.f_bavail * st.f_frsize >= size:  # enough free space?
                    return d
            return util.get_temp_dir()

    def _reduce_handler(handler):
        # Pickling needs a real fd that can be duplicated into the child.
        if handler.fd == -1:
            raise ValueError(
                "Handler is unpicklable because "
                "forking was enabled when it was created"
            )
        return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))

    def _rebuild_handler(size, dupfd):
        detached = dupfd.detach()
        return _FileHandler(size, detached)

    reduction.register(_FileHandler, _reduce_handler)
1109
+
1110
+
1111
def _reduce_memmap(memmap_tensor):
    # Delegate to the tensor's own __reduce__ (handler- or filename-based).
    return memmap_tensor.__reduce__()
1113
+
1114
+
1115
+ reduction.register(MemoryMappedTensor, _reduce_memmap)
1116
+
1117
+
1118
def _proc_args_const(*args, **kwargs):
    """Normalize constructor arguments.

    Splits ``*args``/``**kwargs`` into the tuple
    ``(shape, device, dtype, fill_value, filename)``.  The shape may be given
    positionally (either as several ints, one sequence, or a tensor for the
    nested-tensor case) or as a ``shape=`` keyword.

    Raises:
        TypeError: when no shape can be found in the arguments.
    """
    if args:
        if len(args) == 1 and isinstance(args[0], torch.Tensor):
            # A tensor is passed through untouched: it describes a nested
            # tensor's per-row shapes.
            shape = args[0]
        elif len(args) == 1 and not isinstance(args[0], int):
            shape = torch.Size(args[0])
        else:
            shape = torch.Size(args)
    else:
        # No positional args: the shape must come as a keyword.
        shape = kwargs.pop("shape", None)
        if shape is None:
            raise TypeError("Could not find the shape argument in the arguments.")
        if not isinstance(shape, torch.Tensor):
            shape = torch.Size(shape)
    device = kwargs.pop("device", None)
    dtype = kwargs.pop("dtype", None)
    fill_value = kwargs.pop("fill_value", None)
    filename = kwargs.pop("filename", None)
    return (shape, device, dtype, fill_value, filename)
1141
+
1142
+
1143
+ # Torch functions
1144
+
1145
# Registry mapping a torch function to its MemoryMappedTensor override;
# populated by the ``implements_for_memmap`` decorator below.
MEMMAP_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
1146
+
1147
+
1148
def implements_for_memmap(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Register a torch function override for MemoryMappedTensor.

    Use as ``@implements_for_memmap(torch.some_fn)`` above the override.
    The override is recorded in ``MEMMAP_HANDLED_FUNCTIONS`` keyed by the
    torch function, and returned unchanged.
    """

    @functools.wraps(torch_function)
    def register(override: Callable) -> Callable:
        MEMMAP_HANDLED_FUNCTIONS[torch_function] = override
        return override

    return register
1157
+
1158
+
1159
@implements_for_memmap(torch.unbind)
def _unbind(tensor, dim):
    # torch.unbind override: defer to the MemoryMappedTensor.unbind method so
    # slices are re-wrapped with the parent's backing.
    return tensor.unbind(dim)
1162
+
1163
+
1164
@implements_for_memmap(torch.chunk)
def _chunk(input, chunks, dim=0):
    # torch.chunk override: defer to the MemoryMappedTensor.chunk method so
    # view chunks are re-wrapped with the parent's backing.
    return input.chunk(chunks, dim=dim)
1167
+
1168
+
1169
def _is_writable(file_path):
    """Return whether ``file_path`` is (or can be assumed to be) writable.

    An existing file is checked with ``os.access``; a missing file is
    optimistically assumed to be creatable in its directory.
    """
    path_str = str(file_path)
    if not os.path.exists(path_str):
        # Assume that the file can be written in the directory
        return True
    return os.access(path_str, os.W_OK)