unienv 0.0.1b3__tar.gz → 0.0.1b5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. {unienv-0.0.1b3/unienv.egg-info → unienv-0.0.1b5}/PKG-INFO +1 -1
  2. {unienv-0.0.1b3 → unienv-0.0.1b5}/pyproject.toml +1 -1
  3. {unienv-0.0.1b3 → unienv-0.0.1b5/unienv.egg-info}/PKG-INFO +1 -1
  4. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv.egg-info/SOURCES.txt +5 -2
  5. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/base/common.py +16 -6
  6. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/base/storage.py +13 -3
  7. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/slicestack_batch.py +1 -0
  8. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/replay_buffer/replay_buffer.py +136 -65
  9. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
  10. unienv-0.0.1b5/unienv_data/storages/dict_storage.py +373 -0
  11. unienv-0.0.1b3/unienv_data/storages/common.py → unienv-0.0.1b5/unienv_data/storages/flattened.py +27 -6
  12. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/storages/hdf5.py +48 -3
  13. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/storages/pytorch.py +26 -5
  14. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/storages/transformation.py +16 -3
  15. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/transformations/image_compress.py +22 -9
  16. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/func_wrapper/frame_stack.py +1 -1
  17. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/flatten_utils.py +8 -2
  18. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/tuple.py +4 -4
  19. unienv-0.0.1b5/unienv_interface/transformations/image_resize.py +106 -0
  20. unienv-0.0.1b5/unienv_interface/transformations/iter_transform.py +92 -0
  21. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/utils/symbol_util.py +7 -1
  22. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/funcnode.py +1 -1
  23. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/node.py +2 -2
  24. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/frame_stack.py +1 -1
  25. {unienv-0.0.1b3 → unienv-0.0.1b5}/LICENSE +0 -0
  26. {unienv-0.0.1b3 → unienv-0.0.1b5}/README.md +0 -0
  27. {unienv-0.0.1b3 → unienv-0.0.1b5}/setup.cfg +0 -0
  28. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv.egg-info/dependency_links.txt +0 -0
  29. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv.egg-info/requires.txt +0 -0
  30. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv.egg-info/top_level.txt +0 -0
  31. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/__init__.py +0 -0
  32. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/base/__init__.py +0 -0
  33. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/__init__.py +0 -0
  34. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/backend_compat.py +0 -0
  35. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/combined_batch.py +0 -0
  36. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/framestack_batch.py +0 -0
  37. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/batches/transformations.py +0 -0
  38. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/integrations/pytorch.py +0 -0
  39. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/replay_buffer/__init__.py +0 -0
  40. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/samplers/__init__.py +0 -0
  41. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
  42. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_data/samplers/step_sampler.py +0 -0
  43. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/__init__.py +0 -0
  44. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/__init__.py +0 -0
  45. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/base.py +0 -0
  46. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/jax.py +0 -0
  47. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/numpy.py +0 -0
  48. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/pytorch.py +0 -0
  49. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/backends/serialization.py +0 -0
  50. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/__init__.py +0 -0
  51. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/env.py +0 -0
  52. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/funcenv.py +0 -0
  53. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
  54. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/vec_env.py +0 -0
  55. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/env_base/wrapper.py +0 -0
  56. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/func_wrapper/__init__.py +0 -0
  57. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/func_wrapper/transformation.py +0 -0
  58. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/__init__.py +0 -0
  59. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space.py +0 -0
  60. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/__init__.py +0 -0
  61. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/batch_utils.py +0 -0
  62. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/construct_utils.py +0 -0
  63. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/gym_utils.py +0 -0
  64. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
  65. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/__init__.py +0 -0
  66. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/batched.py +0 -0
  67. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/binary.py +0 -0
  68. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/box.py +0 -0
  69. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/dict.py +0 -0
  70. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/dynamic_box.py +0 -0
  71. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/graph.py +0 -0
  72. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/text.py +0 -0
  73. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/space/spaces/union.py +0 -0
  74. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/__init__.py +0 -0
  75. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
  76. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/chained_transform.py +0 -0
  77. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/dict_transform.py +0 -0
  78. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/filter_dict.py +0 -0
  79. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/rescale.py +0 -0
  80. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/transformations/transformation.py +0 -0
  81. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/utils/control_util.py +0 -0
  82. unienv-0.0.1b3/unienv_interface/utils/data_queue.py → unienv-0.0.1b5/unienv_interface/utils/framestack_queue.py +0 -0
  83. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/utils/seed_util.py +0 -0
  84. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/utils/stateclass.py +0 -0
  85. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/utils/vec_util.py +0 -0
  86. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/__init__.py +0 -0
  87. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/combined_funcnode.py +0 -0
  88. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/combined_node.py +0 -0
  89. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/funcworld.py +0 -0
  90. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/world/world.py +0 -0
  91. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/__init__.py +0 -0
  92. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/action_rescale.py +0 -0
  93. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/backend_compat.py +0 -0
  94. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
  95. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
  96. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/flatten.py +0 -0
  97. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/gym_compat.py +0 -0
  98. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/time_limit.py +0 -0
  99. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/transformation.py +0 -0
  100. {unienv-0.0.1b3 → unienv-0.0.1b5}/unienv_interface/wrapper/video_record.py +0 -0
--- unienv-0.0.1b3/unienv.egg-info/PKG-INFO
+++ unienv-0.0.1b5/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: unienv
-Version: 0.0.1b3
+Version: 0.0.1b5
 Summary: Unified robot environment framework supporting multiple tensor and simulation backends
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
--- unienv-0.0.1b3/pyproject.toml
+++ unienv-0.0.1b5/pyproject.toml
@@ -3,7 +3,7 @@ name = "unienv"
 description = "Unified robot environment framework supporting multiple tensor and simulation backends"
 readme = "README.md"
 license = "MIT"
-version = "0.0.1b3"
+version = "0.0.1b5"
 requires-python = ">= 3.10"
 dependencies = [
     "numpy",
--- unienv-0.0.1b3/PKG-INFO
+++ unienv-0.0.1b5/unienv.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: unienv
-Version: 0.0.1b3
+Version: 0.0.1b5
 Summary: Unified robot environment framework supporting multiple tensor and simulation backends
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
--- unienv-0.0.1b3/unienv.egg-info/SOURCES.txt
+++ unienv-0.0.1b5/unienv.egg-info/SOURCES.txt
@@ -23,7 +23,8 @@ unienv_data/replay_buffer/trajectory_replay_buffer.py
 unienv_data/samplers/__init__.py
 unienv_data/samplers/multiprocessing_sampler.py
 unienv_data/samplers/step_sampler.py
-unienv_data/storages/common.py
+unienv_data/storages/dict_storage.py
+unienv_data/storages/flattened.py
 unienv_data/storages/hdf5.py
 unienv_data/storages/pytorch.py
 unienv_data/storages/transformation.py
@@ -67,10 +68,12 @@ unienv_interface/transformations/batch_and_unbatch.py
 unienv_interface/transformations/chained_transform.py
 unienv_interface/transformations/dict_transform.py
 unienv_interface/transformations/filter_dict.py
+unienv_interface/transformations/image_resize.py
+unienv_interface/transformations/iter_transform.py
 unienv_interface/transformations/rescale.py
 unienv_interface/transformations/transformation.py
 unienv_interface/utils/control_util.py
-unienv_interface/utils/data_queue.py
+unienv_interface/utils/framestack_queue.py
 unienv_interface/utils/seed_util.py
 unienv_interface/utils/stateclass.py
 unienv_interface/utils/symbol_util.py
--- unienv-0.0.1b3/unienv_data/base/common.py
+++ unienv-0.0.1b5/unienv_data/base/common.py
@@ -135,12 +135,25 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
         flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
         self.extend_flattened(flattened_data)
 
+    def extend_from(
+        self,
+        other : 'BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]',
+        chunk_size : int = 8,
+        tqdm : bool = False,
+    ) -> None:
+        n_total = len(other)
+        iterable_start = range(0, n_total, chunk_size)
+        if tqdm:
+            from tqdm import tqdm
+            iterable_start = tqdm(iterable_start, desc="Extending Batch")
+        for start_idx in iterable_start:
+            end_idx = min(start_idx + chunk_size, n_total)
+            data_chunk = other.get_at(slice(start_idx, end_idx))
+            self.extend(data_chunk)
+
     def close(self) -> None:
         pass
 
-    def __del__(self) -> None:
-        self.close()
-
 SamplerBatchT = TypeVar('SamplerBatchT')
 SamplerArrayType = TypeVar('SamplerArrayType')
 SamplerDeviceType = TypeVar('SamplerDeviceType')
@@ -273,6 +286,3 @@ class BatchSampler(
 
     def close(self) -> None:
         pass
-
-    def __del__(self) -> None:
-        self.close()
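The new extend_from copies one batch into another in fixed-size chunks instead of materializing the whole source at once. A minimal usage sketch (src and dst are hypothetical BatchBase instances over the same space; only extend_from and its parameters come from the diff above):

    # Copy src into dst 64 elements at a time; tqdm=True wraps the chunk
    # iterator in a progress bar labeled "Extending Batch".
    dst.extend_from(src, chunk_size=64, tqdm=True)

    # Equivalent manual loop, for reference:
    for start in range(0, len(src), 64):
        dst.extend(src.get_at(slice(start, min(start + 64, len(src)))))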
--- unienv-0.0.1b3/unienv_data/base/storage.py
+++ unienv-0.0.1b5/unienv_data/base/storage.py
@@ -20,6 +20,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
         *args,
         capacity : Optional[int],
         cache_path : Optional[Union[str, os.PathLike]] = None,
+        multiprocessing : bool = False,
         **kwargs
     ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         raise NotImplementedError
@@ -32,6 +33,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
         *,
         capacity : Optional[int] = None,
         read_only : bool = True,
+        multiprocessing : bool = False,
         **kwargs
     ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         raise NotImplementedError
@@ -57,6 +59,17 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
     """
     cache_filename : Optional[Union[str, os.PathLike]] = None
 
+    """
+    Can the storage instance be safely used in multiprocessing environments after creation?
+    If True, the storage can be used in multiprocessing environments.
+    """
+    is_multiprocessing_safe : bool = False
+
+    """
+    Is the storage mutable? If False, the storage is read-only.
+    """
+    is_mutable : bool = True
+
     @property
     def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
         return self.single_instance_space.backend
@@ -128,6 +141,3 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
 
     def close(self) -> None:
         pass
-
-    def __del__(self) -> None:
-        self.close()
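The two new class attributes act as capability flags that callers (ReplayBuffer below) check before enabling shared-process use or writes. A hedged sketch of how a concrete subclass might declare them (SharedMemoryStorage is hypothetical; the flag names and base class come from the diff above):

    class SharedMemoryStorage(SpaceStorage):
        # Instances can be handed to child processes after creation;
        # ReplayBuffer asserts this flag when multiprocessing=True.
        is_multiprocessing_safe = True
        # Writable storage; a read-only (e.g. mmap-backed) storage
        # would set this to False.
        is_mutable = True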
--- unienv-0.0.1b3/unienv_data/batches/slicestack_batch.py
+++ unienv-0.0.1b5/unienv_data/batches/slicestack_batch.py
@@ -33,6 +33,7 @@ class SliceStackedBatch(BatchBase[
         fill_invalid_data : bool = True,
         stack_metadata : bool = False,
     ):
+        assert batch.backend.dtype_is_real_integer(fixed_offset.dtype), "Fixed offset must be an integer tensor"
         assert len(fixed_offset.shape) == 1, "Fixed offset must be a 1D tensor"
         assert fixed_offset.shape[0] > 0, "Fixed offset must have a positive length"
         assert batch.backend.any(fixed_offset == 0), "There should be at least one zero in the fixed offset"
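The added assertion rejects non-integer offsets before the existing shape and zero-membership checks run. Illustrative offsets in plain numpy, assuming a numpy-backed batch (dtype_is_real_integer is the backend predicate named above):

    import numpy as np

    valid = np.array([-2, -1, 0])     # integer dtype, 1D, contains a zero: passes
    bad_dtype = np.array([0.0, 1.0])  # float dtype: now fails the new assertion
    no_zero = np.array([1, 2])        # integer, but fails the existing zero check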
--- unienv-0.0.1b3/unienv_data/replay_buffer/replay_buffer.py
+++ unienv-0.0.1b5/unienv_data/replay_buffer/replay_buffer.py
@@ -1,6 +1,9 @@
 import abc
 import os
 import dataclasses
+import multiprocessing as mp
+from contextlib import nullcontext
+
 from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
 
@@ -51,7 +54,6 @@ def index_with_offset(
     return data_index
 
 class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
-    is_mutable = True
     # =========== Class Attributes ==========
     @staticmethod
     def create(
@@ -60,6 +62,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
         *args,
         cache_path : Optional[Union[str, os.PathLike]] = None,
         capacity : Optional[int] = None,
+        multiprocessing : bool = False,
         **kwargs
     ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         storage_path_relative = "storage" + (storage_cls.single_file_ext or "")
@@ -70,6 +73,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
             *args,
             cache_path=None if cache_path is None else os.path.join(cache_path, storage_path_relative),
             capacity=capacity,
+            multiprocessing=multiprocessing,
             **kwargs
         )
         return ReplayBuffer(
@@ -77,7 +81,8 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
             storage,
             storage_path_relative,
             0,
             0,
-            cache_path=cache_path
+            cache_path=cache_path,
+            multiprocessing=multiprocessing
         )
 
     @staticmethod
@@ -97,6 +102,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
         backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
         device: Optional[BDeviceType] = None,
         read_only : bool = True,
+        multiprocessing : bool = False,
         **storage_kwargs
     ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         with open(os.path.join(path, "metadata.json"), "r") as f:
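Both create and loads now accept a multiprocessing flag and thread it through to the storage and into the buffer itself. A hedged usage sketch (DictStorage, space, path, and backend are stand-ins; the positional arguments before the keywords are abbreviated in the hunks above and assumed here; only the multiprocessing keyword is the new API):

    # Fresh buffer backed by a hypothetical storage class:
    buffer = ReplayBuffer.create(
        DictStorage,
        space,
        capacity=10_000,
        multiprocessing=True,  # requires storage.is_multiprocessing_safe
    )

    # Reloading a dumped buffer with the same flag:
    buffer = ReplayBuffer.loads(
        path, backend, read_only=False, multiprocessing=True,
    )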
@@ -118,52 +124,103 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
             single_instance_space,
             capacity=capacity,
             read_only=read_only,
+            multiprocessing=multiprocessing,
             **storage_kwargs
         )
-        return ReplayBuffer(storage, metadata["storage_path_relative"], count, offset, cache_path=path)
+        return ReplayBuffer(
+            storage,
+            metadata["storage_path_relative"],
+            count,
+            offset,
+            cache_path=path,
+            multiprocessing=multiprocessing
+        )
 
     # =========== Instance Attributes and Methods ==========
     def dumps(self, path : Union[str, os.PathLike]):
-        os.makedirs(path, exist_ok=True)
-        storage_path = os.path.join(path, self.storage_path_relative)
-        self.storage.dumps(storage_path)
-        metadata = {
-            "type": __class__.__name__,
-            "count": self.count,
-            "offset": self.offset,
-            "capacity": self.storage.capacity,
-            "storage_cls": get_full_class_name(type(self.storage)),
-            "storage_path_relative": self.storage_path_relative,
-            "single_instance_space": bsu.space_to_json(self.storage.single_instance_space),
-        }
-        with open(os.path.join(path, "metadata.json"), "w") as f:
-            json.dump(metadata, f)
+        with self._lock_scope():
+            os.makedirs(path, exist_ok=True)
+            storage_path = os.path.join(path, self.storage_path_relative)
+            self.storage.dumps(storage_path)
+            metadata = {
+                "type": __class__.__name__,
+                "count": self.count,
+                "offset": self.offset,
+                "capacity": self.storage.capacity,
+                "storage_cls": get_full_class_name(type(self.storage)),
+                "storage_path_relative": self.storage_path_relative,
+                "single_instance_space": bsu.space_to_json(self.storage.single_instance_space),
+            }
+            with open(os.path.join(path, "metadata.json"), "w") as f:
+                json.dump(metadata, f)
 
     def __init__(
         self,
         storage : SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType],
-        storage_path_relative : Union[str, os.PathLike],
+        storage_path_relative : str,
         count : int = 0,
         offset : int = 0,
         cache_path : Optional[Union[str, os.PathLike]] = None,
+        multiprocessing : bool = False,
     ):
         self.storage = storage
-        self.count = count
-        self.offset = offset
-        self.storage_path_relative = storage_path_relative
+        self._storage_path_relative = storage_path_relative
         self._cache_path = cache_path
+        self._multiprocessing = multiprocessing
+        if multiprocessing:
+            assert storage.is_multiprocessing_safe, "Storage is not multiprocessing safe"
+            self._lock = mp.Lock()
+            self._count_value = mp.Value("q", int(count))
+            self._offset_value = mp.Value("q", int(offset))
+        else:
+            self._lock = None
+            self._count_value = int(count)
+            self._offset_value = int(offset)
+
         super().__init__(
             storage.single_instance_space,
            None
         )
 
+    def _lock_scope(self):
+        if self._lock is not None:
+            return self._lock
+        else:
+            return nullcontext()
+
     @property
     def cache_path(self) -> Optional[Union[str, os.PathLike]]:
         return self._cache_path
 
+    @property
+    def storage_path_relative(self) -> str:
+        return self._storage_path_relative
+
     def __len__(self) -> int:
         return self.count
+
+    @property
+    def count(self) -> int:
+        return self._count_value.value if self._multiprocessing else self._count_value
+
+    @count.setter
+    def count(self, value: int) -> None:
+        if self._multiprocessing:
+            self._count_value.value = int(value)
+        else:
+            self._count_value = int(value)
 
+    @property
+    def offset(self) -> int:
+        return self._offset_value.value if self._multiprocessing else self._offset_value
+
+    @offset.setter
+    def offset(self, value: int) -> None:
+        if self._multiprocessing:
+            self._offset_value.value = int(value)
+        else:
+            self._offset_value = int(value)
+
     @property
     def capacity(self) -> Optional[int]:
         return self.storage.capacity
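_lock_scope returns the shared mp.Lock when the buffer was built with multiprocessing=True and a nullcontext() otherwise, so every access path below can take the lock with a single with statement regardless of mode. The same pattern in isolation (a standalone sketch, not UniEnv code):

    import multiprocessing as mp
    from contextlib import nullcontext

    class Counter:
        def __init__(self, shared: bool = False):
            self._shared = shared
            # mp.Value is a process-shared integer; a plain int otherwise.
            self._lock = mp.Lock() if shared else None
            self._value = mp.Value("q", 0) if shared else 0

        def _lock_scope(self):
            # A real lock when shared, a no-op context manager when not.
            return self._lock if self._lock is not None else nullcontext()

        def increment(self) -> None:
            with self._lock_scope():
                if self._shared:
                    self._value.value += 1
                else:
                    self._value += 1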
@@ -176,12 +233,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
     def device(self) -> Optional[BDeviceType]:
         return self.storage.device
 
+    @property
+    def is_mutable(self) -> bool:
+        return self.storage.is_mutable
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return self._multiprocessing
+
     def get_flattened_at(self, idx):
         return self.get_flattened_at_with_metadata(idx)[0]
 
     def get_flattened_at_with_metadata(self, idx: Union[IndexableType, BArrayType]) -> BArrayType:
         if hasattr(self.storage, "get_flattened"):
-            data = self.storage.get_flattened(idx)
+            with self._lock_scope():
+                data = self.storage.get_flattened(idx)
             return data, None
 
         data, metadata = self.get_at_with_metadata(idx)
@@ -195,19 +261,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
         return self.get_at_with_metadata(idx)[0]
 
     def get_at_with_metadata(self, idx):
-        data_index = index_with_offset(
-            self.backend,
-            idx,
-            self.count,
-            self.offset,
-            self.device
-        )
-        data = self.storage.get(data_index)
+        with self._lock_scope():
+            data_index = index_with_offset(
+                self.backend,
+                idx,
+                self.count,
+                self.offset,
+                self.device
+            )
+            data = self.storage.get(data_index)
         return data, None
 
     def set_flattened_at(self, idx: Union[IndexableType, BArrayType], value: BArrayType) -> None:
         if hasattr(self.storage, "set_flattened"):
-            self.storage.set_flattened(idx, value)
+            with self._lock_scope():
+                self.storage.set_flattened(idx, value)
             return
 
         if isinstance(idx, int):
@@ -217,13 +285,14 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
         self.set_at(idx, value)
 
     def set_at(self, idx, value):
-        self.storage.set(index_with_offset(
-            self.backend,
-            idx,
-            self.count,
-            self.offset,
-            self.device
-        ), value)
+        with self._lock_scope():
+            self.storage.set(index_with_offset(
+                self.backend,
+                idx,
+                self.count,
+                self.offset,
+                self.device
+            ), value)
 
     def extend_flattened(
         self,
@@ -233,35 +302,37 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
         self.extend(unflattened_data)
 
     def extend(self, value):
-        B = sbu.batch_size_data(value)
-        if B == 0:
-            return
-        if self.capacity is None:
-            assert self.offset == 0, "Offset must be 0 when capacity is None"
-            self.storage.extend_length(B)
-            self.storage.set(slice(self.count, self.count + B), value)
-            self.count += B
-            return
-
-        # We have a fixed capacity, only keep the last `capacity` elements
-        if B >= self.capacity:
-            self.storage.set(Ellipsis, sbu.get_at(self._batched_space, value, slice(-self.capacity, None)))
-            self.count = self.capacity
-            self.offset = 0
-            return
-
-        # Otherwise, perform round-robin writes
-        indexes = (self.backend.arange(B, device=self.device) + self.offset + self.count) % self.capacity
-        self.storage.set(indexes, value)
-        outflow = max(0, self.count + B - self.capacity)
-        if outflow > 0:
-            self.offset = (self.offset + outflow) % self.capacity
-        self.count = min(self.count + B, self.capacity)
+        with self._lock_scope():
+            B = sbu.batch_size_data(value)
+            if B == 0:
+                return
+            if self.capacity is None:
+                assert self.offset == 0, "Offset must be 0 when capacity is None"
+                self.storage.extend_length(B)
+                self.storage.set(slice(self.count, self.count + B), value)
+                self.count += B
+                return
+
+            # We have a fixed capacity, only keep the last `capacity` elements
+            if B >= self.capacity:
+                self.storage.set(Ellipsis, sbu.get_at(self._batched_space, value, slice(-self.capacity, None)))
+                self.count = self.capacity
+                self.offset = 0
+                return
+
+            # Otherwise, perform round-robin writes
+            indexes = (self.backend.arange(B, device=self.device) + self.offset + self.count) % self.capacity
+            self.storage.set(indexes, value)
+            outflow = max(0, self.count + B - self.capacity)
+            if outflow > 0:
+                self.offset = (self.offset + outflow) % self.capacity
+            self.count = min(self.count + B, self.capacity)
 
     def clear(self):
-        self.count = 0
-        self.offset = 0
-        self.storage.clear()
+        with self._lock_scope():
+            self.count = 0
+            self.offset = 0
+            self.storage.clear()
 
     def close(self) -> None:
         self.storage.close()
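In the round-robin branch, physical write positions are (offset + count + arange(B)) % capacity, and offset then advances by the number of old elements overwritten. A plain-Python trace of that arithmetic with illustrative numbers:

    capacity, count, offset = 5, 4, 0  # 4 of 5 slots used, oldest at index 0
    B = 3                              # writing 3 new elements

    indexes = [(offset + count + i) % capacity for i in range(B)]
    # indexes == [4, 0, 1]: fills slot 4, then overwrites slots 0 and 1

    outflow = max(0, count + B - capacity)   # 2 old elements evicted
    offset = (offset + outflow) % capacity   # oldest element now at index 2
    count = min(count + B, capacity)         # buffer full: count == 5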