unienv 0.0.1b3__tar.gz → 0.0.1b4__tar.gz

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (98)
  1. {unienv-0.0.1b3/unienv.egg-info → unienv-0.0.1b4}/PKG-INFO +1 -1
  2. {unienv-0.0.1b3 → unienv-0.0.1b4}/pyproject.toml +1 -1
  3. {unienv-0.0.1b3 → unienv-0.0.1b4/unienv.egg-info}/PKG-INFO +1 -1
  4. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/SOURCES.txt +2 -1
  5. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/common.py +16 -6
  6. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/storage.py +11 -3
  7. unienv-0.0.1b4/unienv_data/storages/dict_storage.py +341 -0
  8. unienv-0.0.1b3/unienv_data/storages/common.py → unienv-0.0.1b4/unienv_data/storages/flattened.py +19 -5
  9. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/hdf5.py +42 -3
  10. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/pytorch.py +26 -5
  11. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/transformation.py +0 -2
  12. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/funcnode.py +1 -1
  13. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/node.py +2 -2
  14. {unienv-0.0.1b3 → unienv-0.0.1b4}/LICENSE +0 -0
  15. {unienv-0.0.1b3 → unienv-0.0.1b4}/README.md +0 -0
  16. {unienv-0.0.1b3 → unienv-0.0.1b4}/setup.cfg +0 -0
  17. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/dependency_links.txt +0 -0
  18. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/requires.txt +0 -0
  19. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/top_level.txt +0 -0
  20. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/__init__.py +0 -0
  21. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/__init__.py +0 -0
  22. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/__init__.py +0 -0
  23. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/backend_compat.py +0 -0
  24. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/combined_batch.py +0 -0
  25. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/framestack_batch.py +0 -0
  26. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/slicestack_batch.py +0 -0
  27. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/transformations.py +0 -0
  28. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/integrations/pytorch.py +0 -0
  29. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/__init__.py +0 -0
  30. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/replay_buffer.py +0 -0
  31. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/trajectory_replay_buffer.py +0 -0
  32. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/__init__.py +0 -0
  33. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
  34. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/step_sampler.py +0 -0
  35. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/transformations/image_compress.py +0 -0
  36. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/__init__.py +0 -0
  37. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/__init__.py +0 -0
  38. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/base.py +0 -0
  39. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/jax.py +0 -0
  40. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/numpy.py +0 -0
  41. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/pytorch.py +0 -0
  42. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/serialization.py +0 -0
  43. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/__init__.py +0 -0
  44. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/env.py +0 -0
  45. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/funcenv.py +0 -0
  46. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
  47. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/vec_env.py +0 -0
  48. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/wrapper.py +0 -0
  49. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/__init__.py +0 -0
  50. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/frame_stack.py +0 -0
  51. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/transformation.py +0 -0
  52. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/__init__.py +0 -0
  53. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space.py +0 -0
  54. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/__init__.py +0 -0
  55. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/batch_utils.py +0 -0
  56. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/construct_utils.py +0 -0
  57. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/flatten_utils.py +0 -0
  58. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/gym_utils.py +0 -0
  59. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
  60. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/__init__.py +0 -0
  61. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/batched.py +0 -0
  62. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/binary.py +0 -0
  63. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/box.py +0 -0
  64. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/dict.py +0 -0
  65. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/dynamic_box.py +0 -0
  66. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/graph.py +0 -0
  67. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/text.py +0 -0
  68. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/tuple.py +0 -0
  69. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/union.py +0 -0
  70. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/__init__.py +0 -0
  71. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
  72. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/chained_transform.py +0 -0
  73. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/dict_transform.py +0 -0
  74. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/filter_dict.py +0 -0
  75. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/rescale.py +0 -0
  76. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/transformation.py +0 -0
  77. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/control_util.py +0 -0
  78. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/data_queue.py +0 -0
  79. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/seed_util.py +0 -0
  80. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/stateclass.py +0 -0
  81. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/symbol_util.py +0 -0
  82. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/vec_util.py +0 -0
  83. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/__init__.py +0 -0
  84. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/combined_funcnode.py +0 -0
  85. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/combined_node.py +0 -0
  86. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/funcworld.py +0 -0
  87. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/world.py +0 -0
  88. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/__init__.py +0 -0
  89. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/action_rescale.py +0 -0
  90. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/backend_compat.py +0 -0
  91. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
  92. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
  93. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/flatten.py +0 -0
  94. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/frame_stack.py +0 -0
  95. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/gym_compat.py +0 -0
  96. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/time_limit.py +0 -0
  97. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/transformation.py +0 -0
  98. {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/video_record.py +0 -0
{unienv-0.0.1b3/unienv.egg-info → unienv-0.0.1b4}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: unienv
- Version: 0.0.1b3
+ Version: 0.0.1b4
  Summary: Unified robot environment framework supporting multiple tensor and simulation backends
  License-Expression: MIT
  Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
{unienv-0.0.1b3 → unienv-0.0.1b4}/pyproject.toml
@@ -3,7 +3,7 @@ name = "unienv"
  description = "Unified robot environment framework supporting multiple tensor and simulation backends"
  readme = "README.md"
  license = "MIT"
- version = "0.0.1b3"
+ version = "0.0.1b4"
  requires-python = ">= 3.10"
  dependencies = [
      "numpy",
{unienv-0.0.1b3 → unienv-0.0.1b4/unienv.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: unienv
- Version: 0.0.1b3
+ Version: 0.0.1b4
  Summary: Unified robot environment framework supporting multiple tensor and simulation backends
  License-Expression: MIT
  Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/SOURCES.txt
@@ -23,7 +23,8 @@ unienv_data/replay_buffer/trajectory_replay_buffer.py
  unienv_data/samplers/__init__.py
  unienv_data/samplers/multiprocessing_sampler.py
  unienv_data/samplers/step_sampler.py
- unienv_data/storages/common.py
+ unienv_data/storages/dict_storage.py
+ unienv_data/storages/flattened.py
  unienv_data/storages/hdf5.py
  unienv_data/storages/pytorch.py
  unienv_data/storages/transformation.py
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/common.py
@@ -135,12 +135,25 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
          flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
          self.extend_flattened(flattened_data)

+     def extend_from(
+         self,
+         other : 'BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]',
+         chunk_size : int = 8,
+         tqdm : bool = False,
+     ) -> None:
+         n_total = len(other)
+         iterable_start = range(0, n_total, chunk_size)
+         if tqdm:
+             from tqdm import tqdm
+             iterable_start = tqdm(iterable_start, desc="Extending Batch")
+         for start_idx in iterable_start:
+             end_idx = min(start_idx + chunk_size, n_total)
+             data_chunk = other.get_at(slice(start_idx, end_idx))
+             self.extend(data_chunk)
+
      def close(self) -> None:
          pass

-     def __del__(self) -> None:
-         self.close()
-
  SamplerBatchT = TypeVar('SamplerBatchT')
  SamplerArrayType = TypeVar('SamplerArrayType')
  SamplerDeviceType = TypeVar('SamplerDeviceType')
@@ -273,6 +286,3 @@ class BatchSampler(

      def close(self) -> None:
          pass
-
-     def __del__(self) -> None:
-         self.close()
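The new BatchBase.extend_from copies another batch into this one in fixed-size chunks instead of materializing it all at once. A minimal usage sketch, assuming `src` and `dst` are two compatible BatchBase instances (illustrative names; tqdm=True requires the optional tqdm package):

# Copy `src` into `dst` 128 elements at a time, with a progress bar.
dst.extend_from(src, chunk_size=128, tqdm=True)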
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/storage.py
@@ -57,6 +57,17 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
      """
      cache_filename : Optional[Union[str, os.PathLike]] = None

+     """
+     Can the storage instance be safely used in multiprocessing environments after creation?
+     If True, the storage can be used in multiprocessing environments.
+     """
+     is_multiprocessing_safe : bool = False
+
+     """
+     Is the storage mutable? If False, the storage is read-only.
+     """
+     is_mutable : bool = True
+
      @property
      def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
          return self.single_instance_space.backend
@@ -128,6 +139,3 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,

      def close(self) -> None:
          pass
-
-     def __del__(self) -> None:
-         self.close()
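SpaceStorage now advertises two capability flags that subclasses override: is_mutable gates writes, and is_multiprocessing_safe says whether an instance may be shared across processes. An illustrative guard, assuming `storage` is any SpaceStorage instance and `sample` is data from its space (both hypothetical names):

if storage.is_mutable:
    storage.set(0, sample)  # in-place writes are allowed
if not storage.is_multiprocessing_safe:
    # do not hand this instance to worker processes; re-open one per worker
    # (e.g. via the storage class's load_from) instead
    pass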
unienv-0.0.1b4/unienv_data/storages/dict_storage.py
@@ -0,0 +1,341 @@
+ from importlib import metadata
+ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type, Callable, Mapping
+
+ from unienv_interface.space import Space, DictSpace
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage, BatchT
+
+ import numpy as np
+ import os
+ import json
+
+ def map_transform(
+     data : Dict[str, Any],
+     value_map : Dict[str, Any],
+     fn : Callable[[str, Any, Any], Any], # (str, data, value_map) -> transformed data
+     prefix : str = "",
+ ) -> Tuple[
+     Dict[str, Any], # Transformed data
+     Dict[str, Any], # Residual data
+ ]:
+     transformed_data = {}
+     residual_data = {}
+     for key, value in data.items() if isinstance(data, Mapping) else data.spaces.items():
+         full_key = prefix + key
+         if full_key in value_map:
+             transformed_data[key] = fn(full_key, value, value_map[full_key])
+         elif isinstance(value, Mapping) or isinstance(value, DictSpace):
+             sub_transformed, sub_residual = map_transform(
+                 value,
+                 value_map,
+                 fn,
+                 prefix=full_key + "/",
+             )
+             if len(sub_transformed) > 0:
+                 transformed_data[key] = sub_transformed
+             if len(sub_residual) > 0:
+                 residual_data[key] = sub_residual
+         else:
+             residual_data[key] = value
+     if len(residual_data) > 0 and (prefix + "*") in value_map:
+         residual_transformed = fn(prefix + "*", residual_data, value_map[prefix + "*"])
+         if isinstance(residual_transformed, Mapping) or isinstance(residual_transformed, DictSpace):
+             for key, value in residual_transformed.items():
+                 transformed_data[key] = value
+         residual_data = {}
+     return transformed_data, residual_data
+
+ def get_chained_residual_space(
+     space : DictSpace[BDeviceType, BDtypeType, BRNGType],
+     all_keys : List[str],
+     prefix : str = "",
+ ) -> DictSpace[BDeviceType, BDtypeType, BRNGType]:
+     residual_spaces = {}
+
+     if len(residual_spaces) > 0 and (prefix + "*") in all_keys:
+         return DictSpace(
+             space.backend,
+             {},
+             device=space.device,
+         )
+
+     for key, subspace in space.spaces.items():
+         full_key = prefix + key
+         if full_key in all_keys:
+             continue
+         elif isinstance(subspace, DictSpace):
+             sub_residual = get_chained_residual_space(
+                 subspace,
+                 all_keys,
+                 prefix=full_key + "/",
+             )
+             if len(sub_residual.spaces) > 0:
+                 residual_spaces[key] = sub_residual
+         else:
+             residual_spaces[key] = subspace
+
+     return DictSpace(
+         space.backend,
+         residual_spaces,
+         device=space.device,
+     )
+
+ def get_chained_space(
+     space : DictSpace[BDeviceType, BDtypeType, BRNGType],
+     key_chain : str,
+     all_keys : List[str],
+ ) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
+     if key_chain.endswith("*"):
+         prefix = key_chain[:-1]
+         subspace = get_chained_residual_space(
+             get_chained_space(
+                 space,
+                 prefix,
+                 all_keys,
+             ) if len(prefix) > 0 else space,
+             [key for key in all_keys if key != key_chain],
+             prefix=prefix,
+         )
+         return subspace
+     key_chain = key_chain.split("/")
+     current_space : Space[Any, BDeviceType, BDtypeType, BRNGType]
+     current_space = space
+     for key in key_chain:
+         if len(key) == 0:
+             continue
+         assert isinstance(current_space, DictSpace), \
+             f"Expected DictSpace while traversing key chain, but got {type(current_space)}"
+         current_space = current_space.spaces[key]
+     return current_space
+
+ class DictStorage(SpaceStorage[
+     Dict[str, Any],
+     BArrayType,
+     BDeviceType,
+     BDtypeType,
+     BRNGType,
+ ]):
+     # ========== Class Attributes ==========
+     @classmethod
+     def create(
+         cls,
+         single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+         storage_cls_map : Dict[
+             str,
+             Type[SpaceStorage],
+         ],
+         *args,
+         capacity : Optional[int] = None,
+         cache_path : Optional[str] = None,
+         key_kwargs : Dict[str, Any] = {},
+         type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
+         **kwargs
+     ) -> "DictStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         if cache_path is not None:
+             os.makedirs(cache_path, exist_ok=True)
+
+         storage_map = {}
+         all_keys = list(storage_cls_map.keys())
+         for key, sub_storage_cls in storage_cls_map.items():
+             sub_storage_path = key.replace("/", ".").replace("*", "_default") + (sub_storage_cls.single_file_ext or "")
+             subspace = get_chained_space(single_instance_space, key, all_keys)
+             sub_kwargs = kwargs.copy()
+             if sub_storage_cls in type_kwargs:
+                 sub_kwargs.update(type_kwargs[sub_storage_cls])
+             if key in key_kwargs:
+                 sub_kwargs.update(key_kwargs[key])
+             storage_map[key] = sub_storage_cls.create(
+                 subspace,
+                 *args,
+                 cache_path=None if cache_path is None else os.path.join(cache_path, sub_storage_path),
+                 capacity=capacity,
+                 **sub_kwargs
+             )
+
+         return DictStorage(
+             single_instance_space,
+             storage_map,
+             cache_filename=cache_path,
+         )
+
+     @classmethod
+     def load_from(
+         cls,
+         path : Union[str, os.PathLike],
+         single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
+         *,
+         capacity : Optional[int] = None,
+         read_only : bool = True,
+         key_kwargs : Dict[str, Any] = {},
+         type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
+         **kwargs
+     ) -> "DictStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         metadata_path = os.path.join(path, "dict_storage_metadata.json")
+         assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+         with open(metadata_path, "r") as f:
+             metadata = json.load(f)
+         assert metadata["storage_type"] == cls.__name__, \
+             f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+
+         storage_map_metadata = metadata["storage_map"]
+         storage_map = {}
+
+         all_keys = list(storage_map_metadata.keys())
+         for key, storage_meta in storage_map_metadata.items():
+             storage_cls : Type[SpaceStorage] = get_class_from_full_name(storage_meta["type"])
+             storage_path = storage_meta["path"]
+
+             subspace = get_chained_space(single_instance_space, key, all_keys)
+
+             sub_kwargs = kwargs.copy()
+             if storage_cls in type_kwargs:
+                 sub_kwargs.update(type_kwargs[storage_cls])
+             if key in key_kwargs:
+                 sub_kwargs.update(key_kwargs[key])
+             storage_map[key] = storage_cls.load_from(
+                 os.path.join(path, storage_path),
+                 subspace,
+                 capacity=capacity,
+                 read_only=read_only,
+                 **sub_kwargs
+             )
+
+         return DictStorage(
+             single_instance_space,
+             storage_map,
+             cache_filename=path,
+         )
+
+     # ========== Instance Implementations ==========
+     single_file_ext = None
+
+     def __init__(
+         self,
+         single_instance_space: DictSpace[BDeviceType, BDtypeType, BRNGType],
+         storage_map : Dict[
+             str,
+             SpaceStorage[
+                 BArrayType,
+                 BArrayType,
+                 BDeviceType,
+                 BDtypeType,
+                 BRNGType,
+             ],
+         ],
+         cache_filename: Optional[Union[str, os.PathLike]] = None,
+     ):
+         assert len(storage_map) > 0, "Storage map cannot be empty"
+         first_storage = next(iter(storage_map.values()))
+         init_capacity = first_storage.capacity
+         init_len = len(first_storage)
+         for key, storage in storage_map.items():
+             assert storage.capacity == init_capacity, \
+                 f"All storages must have the same capacity, but storage {key} has capacity {storage.capacity} while first storage has capacity {init_capacity}"
+             assert len(storage) == init_len, \
+                 f"All storages must have the same length, but storage {key} has length {len(storage)} while first storage has length {init_len}"
+
+         super().__init__(single_instance_space)
+         self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
+         self.storage_map = storage_map
+         self._cache_filename = cache_filename if all(
+             storage.cache_filename is not None for storage in storage_map.values()
+         ) else None
+
+     @property
+     def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
+         return self._cache_filename
+
+     @property
+     def is_mutable(self) -> bool:
+         return all(storage.is_mutable for storage in self.storage_map.values())
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return all(storage.is_multiprocessing_safe for storage in self.storage_map.values())
+
+     @property
+     def capacity(self) -> Optional[int]:
+         return next(iter(self.storage_map.values())).capacity
+
+     def extend_length(self, length):
+         for storage in self.storage_map.values():
+             storage.extend_length(length)
+
+     def shrink_length(self, length):
+         for storage in self.storage_map.values():
+             storage.shrink_length(length)
+
+     def __len__(self):
+         return len(next(iter(self.storage_map.values())))
+
+     def get_flattened(self, index):
+         unflat_data = self.get(index)
+         if isinstance(index, int):
+             flat_data = sfu.flatten_data(self.single_instance_space, unflat_data)
+         else:
+             flat_data = sfu.flatten_data(self._batched_instance_space, unflat_data, start_dim=1)
+         return flat_data
+
+     def get(self, index):
+         result, residual = map_transform(
+             self.single_instance_space,
+             self.storage_map,
+             lambda key, space, storage: storage.get(index)
+         )
+         assert len(residual) == 0, f"Some spaces do not have corresponding storage: {residual}"
+         return result
+
+     def set_flattened(self, index, value):
+         if isinstance(index, int):
+             unflat_data = sfu.unflatten_data(self.single_instance_space, value)
+         else:
+             unflat_data = sfu.unflatten_data(self._batched_instance_space, value, start_dim=1)
+         self.set(index, unflat_data)
+
+     def set(self, index, value):
+         _, residual = map_transform(
+             value,
+             self.storage_map,
+             lambda key, data, storage: storage.set(index, data)
+         )
+         assert len(residual) == 0, f"Some spaces do not have corresponding storage: {residual}"
+
+     def get_subspace_by_key(
+         self,
+         key: str,
+     ) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
+         return get_chained_space(
+             self.single_instance_space,
+             key,
+             list(self.storage_map.keys()),
+         )
+
+     def clear(self):
+         for storage in self.storage_map.values():
+             storage.clear()
+
+     def dumps(self, path):
+         os.makedirs(path, exist_ok=True)
+
+         storage_map_metadata = {}
+         for key, storage in self.storage_map.items():
+             sub_storage_path = key.replace("/", ".").replace("*", "_default") + (storage.single_file_ext or "")
+             storage_map_metadata[key] = {
+                 "type": get_full_class_name(type(storage)),
+                 "path": sub_storage_path,
+             }
+             storage.dumps(os.path.join(path, sub_storage_path))
+
+         metadata = {
+             "storage_type": __class__.__name__,
+             "storage_map": storage_map_metadata,
+         }
+         with open(os.path.join(path, "dict_storage_metadata.json"), "w") as f:
+             json.dump(metadata, f)
+
+     def close(self):
+         for storage in self.storage_map.values():
+             storage.close()
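The new DictStorage routes each sub-space of a DictSpace to its own backing storage, keyed by slash-separated paths, with "*" catching all residual keys; dumps and load_from round-trip the layout through a dict_storage_metadata.json file. A hedged sketch of the intended usage, assuming `obs_space` is a DictSpace with a "camera/rgb" key and `InMemoryStorage` is a stand-in for any other SpaceStorage subclass (space construction itself is not shown):

from unienv_data.storages.dict_storage import DictStorage
from unienv_data.storages.hdf5 import HDF5Storage

storage = DictStorage.create(
    obs_space,
    storage_cls_map={
        "camera/rgb": HDF5Storage,   # large image arrays go to an HDF5 file
        "*": InMemoryStorage,        # hypothetical class for the residual keys
    },
    capacity=10_000,
    cache_path="/tmp/replay",
    type_kwargs={HDF5Storage: {"reduce_io": True}},  # per-class extra kwargs
)

storage.dumps("/tmp/replay_dump")   # one sub-storage per key plus metadata JSON
restored = DictStorage.load_from("/tmp/replay_dump", obs_space, read_only=True)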
unienv-0.0.1b3/unienv_data/storages/common.py → unienv-0.0.1b4/unienv_data/storages/flattened.py
@@ -3,9 +3,7 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen

  from unienv_interface.space import Space, BoxSpace
  from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
- from unienv_interface.env_base.env import ContextType, ObsType, ActType
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
- from unienv_interface.backends.numpy import NumpyComputeBackend
  from unienv_interface.utils.symbol_util import *

  from unienv_data.base import SpaceStorage, BatchT
@@ -31,7 +29,7 @@ class FlattenedStorage(SpaceStorage[
          capacity : Optional[int] = None,
          cache_path : Optional[str] = None,
          **kwargs
-     ) -> "FlattenedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+     ) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
          flattened_space = sfu.flatten_space(single_instance_space)
          inner_storage_path = "inner_storage" + (inner_storage_cls.single_file_ext or "")

@@ -49,6 +47,7 @@ class FlattenedStorage(SpaceStorage[
              single_instance_space,
              inner_storage,
              inner_storage_path,
+             cache_filename=cache_path,
          )

      @classmethod
@@ -60,7 +59,7 @@ class FlattenedStorage(SpaceStorage[
          capacity : Optional[int] = None,
          read_only : bool = True,
          **kwargs
-     ) -> "FlattenedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+     ) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
          metadata_path = os.path.join(path, "flattened_metadata.json")
          assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
          with open(metadata_path, "r") as f:
@@ -81,6 +80,7 @@ class FlattenedStorage(SpaceStorage[
              single_instance_space,
              inner_storage,
              inner_storage_path,
+             cache_filename=path,
          )

      # ========== Instance Implementations ==========
@@ -97,6 +97,7 @@ class FlattenedStorage(SpaceStorage[
              BRNGType,
          ],
          inner_storage_path : Union[str, os.PathLike],
+         cache_filename : Optional[Union[str, os.PathLike]] = None,
      ):
          super().__init__(single_instance_space)
          assert inner_storage.backend == single_instance_space.backend, \
@@ -109,7 +110,20 @@ class FlattenedStorage(SpaceStorage[
          self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
          self.inner_storage = inner_storage
          self.inner_storage_path = inner_storage_path
-
+         self._cache_filename = cache_filename
+
+     @property
+     def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
+         return self._cache_filename if self.inner_storage.cache_filename is not None else None
+
+     @property
+     def is_mutable(self) -> bool:
+         return self.inner_storage.is_mutable
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return self.inner_storage.is_multiprocessing_safe
+
      @property
      def capacity(self) -> Optional[int]:
          return self.inner_storage.capacity
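FlattenedStorage (renamed from common.py) now threads a cache location through create/load_from and defers mutability and multiprocessing safety to its inner storage. A hedged sketch, assuming `box_space` exists and that create takes the space and the inner storage class as its leading arguments (argument order is an assumption):

# Illustrative only; exact positional arguments may differ.
flat = FlattenedStorage.create(box_space, HDF5Storage, capacity=1_000, cache_path="/tmp/flat")
print(flat.cache_filename)           # "/tmp/flat" once the inner file exists, else None
print(flat.is_multiprocessing_safe)  # mirrors the inner HDF5 storage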
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/hdf5.py
@@ -2,7 +2,6 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen

  from unienv_interface.space import Space, BoxSpace, DictSpace, TextSpace, BinarySpace
  from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
- from unienv_interface.env_base.env import ContextType, ObsType, ActType
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
  from unienv_interface.backends.numpy import NumpyComputeBackend, NumpyArrayType, NumpyDeviceType, NumpyDtypeType, NumpyRNGType
  from unienv_interface.utils.symbol_util import *
@@ -498,7 +497,7 @@ class HDF5Storage(SpaceStorage[
              capacity=capacity,
              reduce_io=reduce_io,
          )
-
+
      @classmethod
      def load_from(
          cls,
@@ -562,6 +561,20 @@ class HDF5Storage(SpaceStorage[
          assert self.capacity is None or self._len == self.capacity, \
              f"If the storage has a fixed capacity, the length must match the capacity. Expected {self.capacity}, got {self._len}"

+     @property
+     def is_mutable(self) -> bool:
+         return self.root.file.mode != 'r'
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return not self.is_mutable
+
+     @property
+     def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
+         if isinstance(self.root, h5py.File):
+             return self.root.filename
+         return None
+
      def extend_length(self, length):
          assert self.capacity is None, \
              "Cannot extend length of a storage with fixed capacity"
@@ -644,4 +657,30 @@ class HDF5Storage(SpaceStorage[
      def close(self):
          if isinstance(self.root, h5py.File):
              self.root.close()
-         self.root = None
+         self.root = None
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         if isinstance(self.root, h5py.File):
+             state['filename'] = self.root.filename
+             state['mode'] = self.root.file.mode
+         else:
+             state['filename'] = self.root.file.filename
+             state['mode'] = self.root.file.mode
+             state['full_name'] = self.root.name
+         del state['root']
+         return state
+
+     def __setstate__(self, state):
+         if 'filename' in state and 'mode' in state:
+             self.root = h5py.File(
+                 state['filename'],
+                 mode=state['mode']
+             )
+             if 'full_name' in state:
+                 self.root = self.root[state['full_name']]
+                 del state['full_name']
+
+             del state['filename']
+             del state['mode']
+         self.__dict__.update(state)
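With __getstate__/__setstate__ defined, an HDF5Storage can now cross a pickle boundary: the live h5py handle is dropped and the backing filename, open mode, and (for sub-groups) group name are recorded, then the file is re-opened on the other side. A hedged round-trip sketch, assuming `storage` is an existing read-only HDF5Storage:

import pickle

payload = pickle.dumps(storage)  # state carries filename/mode, not the handle
clone = pickle.loads(payload)    # re-opens the HDF5 file with the same mode
assert clone.is_multiprocessing_safe == (not clone.is_mutable)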
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/pytorch.py
@@ -1,7 +1,6 @@
  import os
  import torch
  from unienv_interface.space import Space, BoxSpace
- from unienv_interface.env_base.env import ContextType, ObsType, ActType
  from unienv_interface.backends import ComputeBackend
  from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
  from unienv_data.base import SpaceStorage
@@ -24,6 +23,7 @@ class PytorchTensorStorage(SpaceStorage[
          is_memmap : bool = False,
          cache_path : Optional[str] = None,
          memmap_existok : bool = True,
+         multiprocessing : bool = False,
      ) -> "PytorchTensorStorage":
          assert single_instance_space.backend is PyTorchComputeBackend, \
              f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
@@ -54,8 +54,10 @@ class PytorchTensorStorage(SpaceStorage[
              dtype=single_instance_space.dtype,
              device=single_instance_space.device
          )
-
-         return PytorchTensorStorage(single_instance_space, data)
+         if multiprocessing:
+             data = data.share_memory_()
+
+         return PytorchTensorStorage(single_instance_space, data, mutable=True)

      @classmethod
      def load_from(
@@ -66,11 +68,15 @@ class PytorchTensorStorage(SpaceStorage[
          is_memmap : bool = False,
          capacity : Optional[int] = None,
          read_only : bool = True,
+         multiprocessing : bool = False,
      ) -> "PytorchTensorStorage":
          assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
          assert capacity is not None, "Capacity must be specified when creating a new tensor"
          assert os.path.exists(path), "File does not exist"

+         if is_memmap and not read_only:
+             assert os.access(path, os.W_OK), "File is not writable, cannot open in read-write mode"
+
          target_shape = (capacity, *single_instance_space.shape)
          target_data = MemoryMappedTensor.from_filename(
              path,
@@ -88,11 +94,14 @@ class PytorchTensorStorage(SpaceStorage[
              dtype=single_instance_space.dtype,
              device=single_instance_space.device
          )
-         data.copy_(target_data)
+         if multiprocessing:
+             data = data.share_memory_()
+         data = data.copy_(target_data)

          return PytorchTensorStorage(
              single_instance_space,
-             data
+             data,
+             mutable=not read_only
          )

      # ========== Instance Implementations ==========
@@ -104,6 +113,7 @@ class PytorchTensorStorage(SpaceStorage[
          self,
          single_instance_space : BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
          data : Union[torch.Tensor, MemoryMappedTensor],
+         mutable : bool = True,
      ):
          assert single_instance_space.shape == data.shape[1:], \
              f"Single instance space shape {single_instance_space.shape} does not match data shape {data.shape[1:]}"
@@ -111,6 +121,7 @@ class PytorchTensorStorage(SpaceStorage[
              single_instance_space
          )
          self.data = data
+         self._mutable = mutable

      @property
      def device(self) -> Optional[PyTorchDeviceType]:
@@ -122,6 +133,14 @@ class PytorchTensorStorage(SpaceStorage[
              return self.data.filename
          return None

+     @property
+     def is_mutable(self) -> bool:
+         return self._mutable
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return self.data.is_shared()
+
      @property
      def capacity(self) -> int:
          return self.data.shape[0]
@@ -134,9 +153,11 @@ class PytorchTensorStorage(SpaceStorage[
          return self.data[index]

      def set(self, index : Union[int, slice, torch.Tensor], value : torch.Tensor) -> None:
+         assert self.is_mutable, "Storage is not mutable"
          self.data[index] = value

      def clear(self) -> None:
+         assert self.is_mutable, "Storage is not mutable"
          pass

      def dumps(self, path: Union[str, os.PathLike]) -> None:
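For the PyTorch storage, passing multiprocessing=True moves the backing tensor into shared memory via torch.Tensor.share_memory_(), which is exactly what is_multiprocessing_safe reports. A hedged sketch, assuming `box_space` is a PyTorch-backed BoxSpace and that create accepts a capacity argument (not shown in this hunk):

import torch

storage = PytorchTensorStorage.create(box_space, capacity=1_000, multiprocessing=True)
assert storage.is_multiprocessing_safe        # the backing tensor reports is_shared()
storage.set(0, torch.zeros(box_space.shape))  # passes the new mutability guard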
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/transformation.py
@@ -3,9 +3,7 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen

  from unienv_interface.space import Space, BoxSpace
  from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
- from unienv_interface.env_base.env import ContextType, ObsType, ActType
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
- from unienv_interface.backends.numpy import NumpyComputeBackend
  from unienv_interface.utils.symbol_util import *
  from unienv_interface.transformations import DataTransformation

{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/funcnode.py
@@ -21,7 +21,6 @@ class FuncWorldNode(ABC, Generic[
      """

      name : str
-     world : FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]
      control_timestep : Optional[float] = None
      context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
      observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
@@ -29,6 +28,7 @@ class FuncWorldNode(ABC, Generic[
      has_reward : bool = False
      has_termination_signal : bool = False
      has_truncation_signal : bool = False
+     world : Optional[FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]] = None

      @property
      def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
{unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/node.py
@@ -8,7 +8,7 @@ from .world import World

  class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceType, BDtypeType, BRNGType]):
      """
-     Each `WorldNode` in the simulated / real world will manage some aspect of the environment.
+     Each `WorldNode` in the simulated / real world will manage some aspect of the environment. This can include sensors, robots, or other entities that interact with the world.
      How the methods in this class will be called once environment resets:
      `World.reset(...)` -> `WorldNode.reset(...)` -> `WorldNode.after_reset(...)` -> `WorldNode.get_observation(...)` -> World can start stepping normally
      How the methods in this class will be called during a environment step:
@@ -16,7 +16,6 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
      """

      name : str
-     world : World[BArrayType, BDeviceType, BDtypeType, BRNGType]
      control_timestep : Optional[float] = None
      context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
      observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
@@ -24,6 +23,7 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
      has_reward : bool = False
      has_termination_signal : bool = False
      has_truncation_signal : bool = False
+     world : Optional[World[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None

      @property
      def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]: