unienv 0.0.1b3__tar.gz → 0.0.1b4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b3/unienv.egg-info → unienv-0.0.1b4}/PKG-INFO +1 -1
- {unienv-0.0.1b3 → unienv-0.0.1b4}/pyproject.toml +1 -1
- {unienv-0.0.1b3 → unienv-0.0.1b4/unienv.egg-info}/PKG-INFO +1 -1
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/SOURCES.txt +2 -1
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/common.py +16 -6
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/storage.py +11 -3
- unienv-0.0.1b4/unienv_data/storages/dict_storage.py +341 -0
- unienv-0.0.1b3/unienv_data/storages/common.py → unienv-0.0.1b4/unienv_data/storages/flattened.py +19 -5
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/hdf5.py +42 -3
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/pytorch.py +26 -5
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/storages/transformation.py +0 -2
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/funcnode.py +1 -1
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/node.py +2 -2
- {unienv-0.0.1b3 → unienv-0.0.1b4}/LICENSE +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/README.md +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/setup.cfg +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/dependency_links.txt +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/requires.txt +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv.egg-info/top_level.txt +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/base/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/backend_compat.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/combined_batch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/framestack_batch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/slicestack_batch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/batches/transformations.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/integrations/pytorch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/replay_buffer.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/replay_buffer/trajectory_replay_buffer.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/samplers/step_sampler.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_data/transformations/image_compress.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/base.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/jax.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/numpy.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/pytorch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/backends/serialization.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/env.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/funcenv.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/vec_env.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/env_base/wrapper.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/frame_stack.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/func_wrapper/transformation.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/batch_utils.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/construct_utils.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/flatten_utils.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/gym_utils.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/batched.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/binary.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/box.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/dict.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/dynamic_box.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/graph.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/text.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/tuple.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/space/spaces/union.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/chained_transform.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/dict_transform.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/filter_dict.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/rescale.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/transformations/transformation.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/control_util.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/data_queue.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/seed_util.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/stateclass.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/symbol_util.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/utils/vec_util.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/combined_funcnode.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/combined_node.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/funcworld.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/world/world.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/__init__.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/action_rescale.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/backend_compat.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/flatten.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/frame_stack.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/gym_compat.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/time_limit.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/transformation.py +0 -0
- {unienv-0.0.1b3 → unienv-0.0.1b4}/unienv_interface/wrapper/video_record.py +0 -0
|
@@ -23,7 +23,8 @@ unienv_data/replay_buffer/trajectory_replay_buffer.py
|
|
|
23
23
|
unienv_data/samplers/__init__.py
|
|
24
24
|
unienv_data/samplers/multiprocessing_sampler.py
|
|
25
25
|
unienv_data/samplers/step_sampler.py
|
|
26
|
-
unienv_data/storages/
|
|
26
|
+
unienv_data/storages/dict_storage.py
|
|
27
|
+
unienv_data/storages/flattened.py
|
|
27
28
|
unienv_data/storages/hdf5.py
|
|
28
29
|
unienv_data/storages/pytorch.py
|
|
29
30
|
unienv_data/storages/transformation.py
|
|
@@ -135,12 +135,25 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
|
|
|
135
135
|
flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
|
|
136
136
|
self.extend_flattened(flattened_data)
|
|
137
137
|
|
|
138
|
+
def extend_from(
    self,
    other : 'BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]',
    chunk_size : int = 8,
    tqdm : bool = False,
) -> None:
    """
    Append every element of ``other`` to this batch, ``chunk_size`` elements at a time.

    Args:
        other: Batch to copy from; read with ``other.get_at(slice(start, end))``.
        chunk_size: Maximum number of elements fetched from ``other`` per ``extend`` call.
            Must be positive.
        tqdm: If True, wrap the chunk loop in a ``tqdm`` progress bar (requires the
            optional ``tqdm`` package).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    n_total = len(other)
    chunk_starts = range(0, n_total, chunk_size)
    if tqdm:
        # Lazy import keeps tqdm optional; aliased so it does not shadow the
        # boolean ``tqdm`` parameter.
        from tqdm import tqdm as _tqdm
        chunk_starts = _tqdm(chunk_starts, desc="Extending Batch")
    # Iterate the (possibly progress-wrapped) iterable; looping over a fresh
    # ``range`` would leave the progress bar created but never advanced.
    for start_idx in chunk_starts:
        end_idx = min(start_idx + chunk_size, n_total)
        data_chunk = other.get_at(slice(start_idx, end_idx))
        self.extend(data_chunk)
|
|
153
|
+
|
|
138
154
|
def close(self) -> None:
    """Release any resources held by this object. This implementation holds none and does nothing."""
    return None
|
|
140
156
|
|
|
141
|
-
def __del__(self) -> None:
|
|
142
|
-
self.close()
|
|
143
|
-
|
|
144
157
|
SamplerBatchT = TypeVar('SamplerBatchT')
|
|
145
158
|
SamplerArrayType = TypeVar('SamplerArrayType')
|
|
146
159
|
SamplerDeviceType = TypeVar('SamplerDeviceType')
|
|
@@ -273,6 +286,3 @@ class BatchSampler(
|
|
|
273
286
|
|
|
274
287
|
def close(self) -> None:
    """Release any resources held by this object. This implementation holds none and does nothing."""
    return None
|
|
276
|
-
|
|
277
|
-
def __del__(self) -> None:
|
|
278
|
-
self.close()
|
|
@@ -57,6 +57,17 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
57
57
|
"""
|
|
58
58
|
cache_filename : Optional[Union[str, os.PathLike]] = None
|
|
59
59
|
|
|
60
|
+
"""
|
|
61
|
+
Can the storage instance be safely used in multiprocessing environments after creation?
|
|
62
|
+
If True, the storage can be used in multiprocessing environments.
|
|
63
|
+
"""
|
|
64
|
+
is_multiprocessing_safe : bool = False
|
|
65
|
+
|
|
66
|
+
"""
|
|
67
|
+
Is the storage mutable? If False, the storage is read-only.
|
|
68
|
+
"""
|
|
69
|
+
is_mutable : bool = True
|
|
70
|
+
|
|
60
71
|
@property
|
|
61
72
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
62
73
|
return self.single_instance_space.backend
|
|
@@ -128,6 +139,3 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
128
139
|
|
|
129
140
|
def close(self) -> None:
    """Release any resources held by this object. This implementation holds none and does nothing."""
    return None
|
|
131
|
-
|
|
132
|
-
def __del__(self) -> None:
|
|
133
|
-
self.close()
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
from importlib import metadata
|
|
2
|
+
from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type, Callable, Mapping
|
|
3
|
+
|
|
4
|
+
from unienv_interface.space import Space, DictSpace
|
|
5
|
+
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
7
|
+
from unienv_interface.utils.symbol_util import *
|
|
8
|
+
|
|
9
|
+
from unienv_data.base import SpaceStorage, BatchT
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import os
|
|
13
|
+
import json
|
|
14
|
+
|
|
15
|
+
def map_transform(
    data : Dict[str, Any],
    value_map : Dict[str, Any],
    fn : Callable[[str, Any, Any], Any], # (str, data, value_map) -> transformed data
    prefix : str = "",
) -> Tuple[
    Dict[str, Any], # Transformed data
    Dict[str, Any], # Residual data
]:
    """
    Recursively apply ``fn`` to the entries of a nested dict whose key path is in ``value_map``.

    Key paths are the slash-joined keys from the root (e.g. ``"obs/image"``). An entry whose
    path is in ``value_map`` is handed to ``fn`` as a whole. A nested mapping (or
    ``DictSpace``) whose own path is not in ``value_map`` is descended into. A
    ``"<prefix>*"`` entry in ``value_map`` is a catch-all for its level. Every entry at that
    level (including nested leftovers) not consumed otherwise is gathered into one dict and
    passed to ``fn`` in a single call.

    Args:
        data: A mapping, or a ``DictSpace`` (whose ``.spaces`` are traversed).
        value_map: Flat mapping from key paths (optionally ``"<prefix>*"``) to the value
            passed as the third argument of ``fn``.
        fn: Called as ``fn(full_key, sub_data, value_map[full_key])``.
        prefix: Key path of ``data`` inside the root structure; ends in ``/`` when non-empty.

    Returns:
        ``(transformed, residual)``. ``transformed`` mirrors the nesting of ``data`` and holds
        the return values of ``fn``. When a catch-all call returns a mapping or ``DictSpace``,
        its entries are merged into the current level. ``residual`` holds the entries that no
        key path and no catch-all consumed.
    """
    transformed_data = {}
    residual_data = {}
    items = data.items() if isinstance(data, Mapping) else data.spaces.items()
    for key, value in items:
        full_key = prefix + key
        if full_key in value_map:
            transformed_data[key] = fn(full_key, value, value_map[full_key])
        elif isinstance(value, (Mapping, DictSpace)):
            sub_transformed, sub_residual = map_transform(
                value,
                value_map,
                fn,
                prefix=full_key + "/",
            )
            if len(sub_transformed) > 0:
                transformed_data[key] = sub_transformed
            if len(sub_residual) > 0:
                residual_data[key] = sub_residual
        else:
            residual_data[key] = value

    wildcard_key = prefix + "*"
    if len(residual_data) > 0 and wildcard_key in value_map:
        residual_transformed = fn(wildcard_key, residual_data, value_map[wildcard_key])
        # Merge mapping-like results (e.g. a getter's dict) into this level.
        if isinstance(residual_transformed, Mapping):
            transformed_data.update(residual_transformed)
        elif isinstance(residual_transformed, DictSpace):
            transformed_data.update(residual_transformed.spaces)
        # The catch-all consumed the residual whatever fn returned. Setters return
        # None, and must not leave the residual behind as "unhandled".
        residual_data = {}
    return transformed_data, residual_data
|
|
50
|
+
|
|
51
|
+
def get_chained_residual_space(
    space : DictSpace[BDeviceType, BDtypeType, BRNGType],
    all_keys : List[str],
    prefix : str = "",
) -> DictSpace[BDeviceType, BDtypeType, BRNGType]:
    """
    Return the part of ``space`` that is not claimed by any key path in ``all_keys``.

    An entry is claimed when its full key path (``prefix + key``) is in ``all_keys``.
    Nested ``DictSpace`` entries are filtered recursively, and nested levels that end up
    empty are omitted. If ``prefix + "*"`` is in ``all_keys``, that catch-all claims the
    whole level and an empty ``DictSpace`` is returned. This matches how ``map_transform``
    consumes nested residual data with a catch-all. Callers resolving a catch-all key itself
    should leave that key out of ``all_keys`` (``get_chained_space`` does this).

    Args:
        space: The ``DictSpace`` to filter.
        all_keys: Key paths that already have a dedicated owner.
        prefix: Key path of ``space`` inside the root space; ends in ``/`` when non-empty.

    Returns:
        A ``DictSpace`` on the same backend/device containing only unclaimed entries.
    """
    if (prefix + "*") in all_keys:
        # The catch-all at this level owns everything not matched explicitly, so
        # nothing is left over here. The previous check also required a non-empty
        # residual dict before it had been filled, which made this branch dead.
        return DictSpace(
            space.backend,
            {},
            device=space.device,
        )

    residual_spaces = {}
    for key, subspace in space.spaces.items():
        full_key = prefix + key
        if full_key in all_keys:
            continue
        if isinstance(subspace, DictSpace):
            sub_residual = get_chained_residual_space(
                subspace,
                all_keys,
                prefix=full_key + "/",
            )
            if len(sub_residual.spaces) > 0:
                residual_spaces[key] = sub_residual
        else:
            residual_spaces[key] = subspace

    return DictSpace(
        space.backend,
        residual_spaces,
        device=space.device,
    )
|
|
85
|
+
|
|
86
|
+
def get_chained_space(
    space : DictSpace[BDeviceType, BDtypeType, BRNGType],
    key_chain : str,
    all_keys : List[str],
) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
    """
    Resolve the subspace of ``space`` addressed by the slash-separated path ``key_chain``.

    A path ending in ``*`` (e.g. ``"*"`` or ``"obs/*"``) resolves to the residual space at
    that level. This is every entry under the prefix that is not claimed by another key in
    ``all_keys``. Empty path segments (such as a trailing ``/``) are skipped.

    Args:
        space: Root ``DictSpace``.
        key_chain: Slash-separated key path, optionally ending in ``*``.
        all_keys: Every key path in use; needed to compute catch-all residuals.

    Returns:
        The addressed subspace.
    """
    if not key_chain.endswith("*"):
        node = space
        for part in key_chain.split("/"):
            if not part:
                continue
            assert isinstance(node, DictSpace), \
                f"Expected DictSpace while traversing key chain, but got {type(node)}"
            node = node.spaces[part]
        return node

    # Catch-all: locate the parent level, then strip everything claimed by other keys.
    prefix = key_chain[:-1]
    parent = space if len(prefix) == 0 else get_chained_space(space, prefix, all_keys)
    other_keys = [k for k in all_keys if k != key_chain]
    return get_chained_residual_space(parent, other_keys, prefix=prefix)
|
|
113
|
+
|
|
114
|
+
class DictStorage(SpaceStorage[
    Dict[str, Any],
    BArrayType,
    BDeviceType,
    BDtypeType,
    BRNGType,
]):
    """
    Storage for dict-structured data that routes each key path to its own sub-storage.

    Keys of the storage map are slash-separated paths into ``single_instance_space``
    (e.g. ``"obs/image"``). A key ending in ``*`` (``"*"``, ``"obs/*"``) is a catch-all
    storing every entry at that level not claimed by a more specific key. On disk, each
    sub-storage is written to its own path under the storage directory, next to a
    ``dict_storage_metadata.json`` index describing the mapping.
    """

    # ========== Class Attributes ==========
    @staticmethod
    def _sub_storage_path(key : str, storage : Any) -> str:
        """Relative on-disk name for ``key``'s sub-storage (``/`` -> ``.``, ``*`` -> ``_default``, plus its file extension)."""
        return key.replace("/", ".").replace("*", "_default") + (storage.single_file_ext or "")

    @staticmethod
    def _merge_sub_kwargs(
        kwargs : Dict[str, Any],
        key : str,
        storage_cls : Type[SpaceStorage],
        key_kwargs : Optional[Dict[str, Any]],
        type_kwargs : Optional[Dict[Type[SpaceStorage], Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Combine shared ``kwargs`` with per-type, then per-key overrides (later ones win)."""
        sub_kwargs = kwargs.copy()
        if type_kwargs is not None and storage_cls in type_kwargs:
            sub_kwargs.update(type_kwargs[storage_cls])
        if key_kwargs is not None and key in key_kwargs:
            sub_kwargs.update(key_kwargs[key])
        return sub_kwargs

    @classmethod
    def create(
        cls,
        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
        storage_cls_map : Dict[
            str,
            Type[SpaceStorage],
        ],
        *args,
        capacity : Optional[int] = None,
        cache_path : Optional[str] = None,
        key_kwargs : Optional[Dict[str, Any]] = None,
        type_kwargs : Optional[Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]]] = None,
        **kwargs
    ) -> "DictStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """
        Create a storage with one freshly created sub-storage per key of ``storage_cls_map``.

        Args:
            single_instance_space: Space of a single (unbatched) dict instance.
            storage_cls_map: Key path (or catch-all such as ``"*"``) -> ``SpaceStorage``
                subclass used for that part of the space.
            *args: Forwarded positionally to every sub-storage's ``create``.
            capacity: Forwarded to every sub-storage.
            cache_path: Optional directory (created if missing); each sub-storage is created
                at its own path inside it.
            key_kwargs: Extra keyword arguments for specific key paths. They take precedence
                over ``type_kwargs`` and ``kwargs``.
            type_kwargs: Extra keyword arguments for keys whose storage class is the given
                type. They take precedence over ``kwargs``.
            **kwargs: Keyword arguments forwarded to every sub-storage's ``create``.
        """
        if cache_path is not None:
            os.makedirs(cache_path, exist_ok=True)

        all_keys = list(storage_cls_map.keys())
        storage_map = {}
        for key, sub_storage_cls in storage_cls_map.items():
            sub_storage_path = cls._sub_storage_path(key, sub_storage_cls)
            subspace = get_chained_space(single_instance_space, key, all_keys)
            sub_kwargs = cls._merge_sub_kwargs(kwargs, key, sub_storage_cls, key_kwargs, type_kwargs)
            storage_map[key] = sub_storage_cls.create(
                subspace,
                *args,
                cache_path=None if cache_path is None else os.path.join(cache_path, sub_storage_path),
                capacity=capacity,
                **sub_kwargs
            )

        return cls(
            single_instance_space,
            storage_map,
            cache_filename=cache_path,
        )

    @classmethod
    def load_from(
        cls,
        path : Union[str, os.PathLike],
        single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
        *,
        capacity : Optional[int] = None,
        read_only : bool = True,
        key_kwargs : Optional[Dict[str, Any]] = None,
        type_kwargs : Optional[Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]]] = None,
        **kwargs
    ) -> "DictStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """
        Load a storage previously written by ``dumps`` from the directory ``path``.

        Each sub-storage is loaded with the class and relative path recorded in
        ``dict_storage_metadata.json``. Keyword arguments are merged as in ``create``.

        Raises:
            AssertionError: If the metadata file is missing or was written by a different
                storage class.
        """
        metadata_path = os.path.join(path, "dict_storage_metadata.json")
        assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
        with open(metadata_path, "r") as f:
            storage_metadata = json.load(f)
        assert storage_metadata["storage_type"] == cls.__name__, \
            f"Expected storage type {cls.__name__}, but found {storage_metadata['storage_type']}"

        storage_map_metadata = storage_metadata["storage_map"]
        all_keys = list(storage_map_metadata.keys())
        storage_map = {}
        for key, storage_meta in storage_map_metadata.items():
            storage_cls : Type[SpaceStorage] = get_class_from_full_name(storage_meta["type"])
            subspace = get_chained_space(single_instance_space, key, all_keys)
            sub_kwargs = cls._merge_sub_kwargs(kwargs, key, storage_cls, key_kwargs, type_kwargs)
            storage_map[key] = storage_cls.load_from(
                os.path.join(path, storage_meta["path"]),
                subspace,
                capacity=capacity,
                read_only=read_only,
                **sub_kwargs
            )

        return cls(
            single_instance_space,
            storage_map,
            cache_filename=path,
        )

    # ========== Instance Implementations ==========
    single_file_ext = None

    def __init__(
        self,
        single_instance_space: DictSpace[BDeviceType, BDtypeType, BRNGType],
        storage_map : Dict[
            str,
            SpaceStorage[
                Any,
                BArrayType,
                BDeviceType,
                BDtypeType,
                BRNGType,
            ],
        ],
        cache_filename: Optional[Union[str, os.PathLike]] = None,
    ):
        """
        Args:
            single_instance_space: Space of a single (unbatched) dict instance.
            storage_map: Non-empty mapping from key path to sub-storage. All sub-storages
                must share the same capacity and current length.
            cache_filename: Directory backing this storage. It is kept only if every
                sub-storage reports a ``cache_filename``; otherwise it is stored as ``None``.
        """
        assert len(storage_map) > 0, "Storage map cannot be empty"
        first_storage = next(iter(storage_map.values()))
        init_capacity = first_storage.capacity
        init_len = len(first_storage)
        for key, storage in storage_map.items():
            assert storage.capacity == init_capacity, \
                f"All storages must have the same capacity, but storage {key} has capacity {storage.capacity} while first storage has capacity {init_capacity}"
            assert len(storage) == init_len, \
                f"All storages must have the same length, but storage {key} has length {len(storage)} while first storage has length {init_len}"

        super().__init__(single_instance_space)
        self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
        self.storage_map = storage_map
        self._cache_filename = cache_filename if all(
            storage.cache_filename is not None for storage in storage_map.values()
        ) else None

    @property
    def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
        """Directory backing this storage, or ``None`` if any sub-storage is not file-backed."""
        return self._cache_filename

    @property
    def is_mutable(self) -> bool:
        """True only if every sub-storage is mutable."""
        return all(storage.is_mutable for storage in self.storage_map.values())

    @property
    def is_multiprocessing_safe(self) -> bool:
        """True only if every sub-storage is multiprocessing safe."""
        return all(storage.is_multiprocessing_safe for storage in self.storage_map.values())

    @property
    def capacity(self) -> Optional[int]:
        """Capacity shared by all sub-storages (checked equal in ``__init__``)."""
        return next(iter(self.storage_map.values())).capacity

    def extend_length(self, length):
        """Forward ``extend_length`` to every sub-storage."""
        for storage in self.storage_map.values():
            storage.extend_length(length)

    def shrink_length(self, length):
        """Forward ``shrink_length`` to every sub-storage."""
        for storage in self.storage_map.values():
            storage.shrink_length(length)

    def __len__(self):
        """Length of the sub-storages (checked equal in ``__init__``)."""
        return len(next(iter(self.storage_map.values())))

    def get_flattened(self, index):
        """
        Like ``get``, but flattened with ``sfu.flatten_data``.

        An ``int`` index is flattened against the single-instance space; any other index is
        flattened against the batched space with ``start_dim=1``.
        """
        unflat_data = self.get(index)
        if isinstance(index, int):
            return sfu.flatten_data(self.single_instance_space, unflat_data)
        return sfu.flatten_data(self._batched_instance_space, unflat_data, start_dim=1)

    def get(self, index):
        """Read ``index`` from every sub-storage and reassemble the nested dict of ``single_instance_space``."""
        result, residual = map_transform(
            self.single_instance_space,
            self.storage_map,
            lambda key, space, storage: storage.get(index),
        )
        assert len(residual) == 0, f"Some spaces do not have corresponding storage: {residual}"
        return result

    def set_flattened(self, index, value):
        """Inverse of ``get_flattened``: unflatten ``value`` with ``sfu.unflatten_data``, then ``set`` it."""
        if isinstance(index, int):
            unflat_data = sfu.unflatten_data(self.single_instance_space, value)
        else:
            unflat_data = sfu.unflatten_data(self._batched_instance_space, value, start_dim=1)
        self.set(index, unflat_data)

    def set(self, index, value):
        """Write the nested dict ``value`` at ``index``, routing each part to its sub-storage."""
        def _write(key, data, storage):
            storage.set(index, data)
            # Return an (empty) mapping so map_transform treats a catch-all ("*")
            # write as having consumed its residual entries.
            return {}

        _, residual = map_transform(value, self.storage_map, _write)
        assert len(residual) == 0, f"Some spaces do not have corresponding storage: {residual}"

    def get_subspace_by_key(
        self,
        key: str,
    ) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
        """Subspace of ``single_instance_space`` addressed by ``key`` (catch-all keys resolve to their residual)."""
        return get_chained_space(
            self.single_instance_space,
            key,
            list(self.storage_map.keys()),
        )

    def clear(self):
        """Clear every sub-storage."""
        for storage in self.storage_map.values():
            storage.clear()

    def dumps(self, path):
        """Write every sub-storage under ``path`` plus a ``dict_storage_metadata.json`` index."""
        os.makedirs(path, exist_ok=True)

        storage_map_metadata = {}
        for key, storage in self.storage_map.items():
            sub_storage_path = self._sub_storage_path(key, storage)
            storage_map_metadata[key] = {
                "type": get_full_class_name(type(storage)),
                "path": sub_storage_path,
            }
            storage.dumps(os.path.join(path, sub_storage_path))

        storage_metadata = {
            # type(self) rather than __class__ so subclasses round-trip through
            # cls.load_from, which compares this name with cls.__name__.
            "storage_type": type(self).__name__,
            "storage_map": storage_map_metadata,
        }
        with open(os.path.join(path, "dict_storage_metadata.json"), "w") as f:
            json.dump(storage_metadata, f)

    def close(self):
        """Close every sub-storage."""
        for storage in self.storage_map.values():
            storage.close()
|
unienv-0.0.1b3/unienv_data/storages/common.py → unienv-0.0.1b4/unienv_data/storages/flattened.py
RENAMED
|
@@ -3,9 +3,7 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen
|
|
|
3
3
|
|
|
4
4
|
from unienv_interface.space import Space, BoxSpace
|
|
5
5
|
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
-
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
7
6
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
|
-
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
9
7
|
from unienv_interface.utils.symbol_util import *
|
|
10
8
|
|
|
11
9
|
from unienv_data.base import SpaceStorage, BatchT
|
|
@@ -31,7 +29,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
31
29
|
capacity : Optional[int] = None,
|
|
32
30
|
cache_path : Optional[str] = None,
|
|
33
31
|
**kwargs
|
|
34
|
-
) -> "FlattenedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
32
|
+
) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
35
33
|
flattened_space = sfu.flatten_space(single_instance_space)
|
|
36
34
|
inner_storage_path = "inner_storage" + (inner_storage_cls.single_file_ext or "")
|
|
37
35
|
|
|
@@ -49,6 +47,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
49
47
|
single_instance_space,
|
|
50
48
|
inner_storage,
|
|
51
49
|
inner_storage_path,
|
|
50
|
+
cache_filename=cache_path,
|
|
52
51
|
)
|
|
53
52
|
|
|
54
53
|
@classmethod
|
|
@@ -60,7 +59,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
60
59
|
capacity : Optional[int] = None,
|
|
61
60
|
read_only : bool = True,
|
|
62
61
|
**kwargs
|
|
63
|
-
) -> "FlattenedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
62
|
+
) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
64
63
|
metadata_path = os.path.join(path, "flattened_metadata.json")
|
|
65
64
|
assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
|
|
66
65
|
with open(metadata_path, "r") as f:
|
|
@@ -81,6 +80,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
81
80
|
single_instance_space,
|
|
82
81
|
inner_storage,
|
|
83
82
|
inner_storage_path,
|
|
83
|
+
cache_filename=path,
|
|
84
84
|
)
|
|
85
85
|
|
|
86
86
|
# ========== Instance Implementations ==========
|
|
@@ -97,6 +97,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
97
97
|
BRNGType,
|
|
98
98
|
],
|
|
99
99
|
inner_storage_path : Union[str, os.PathLike],
|
|
100
|
+
cache_filename : Optional[Union[str, os.PathLike]] = None,
|
|
100
101
|
):
|
|
101
102
|
super().__init__(single_instance_space)
|
|
102
103
|
assert inner_storage.backend == single_instance_space.backend, \
|
|
@@ -109,7 +110,20 @@ class FlattenedStorage(SpaceStorage[
|
|
|
109
110
|
self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
|
|
110
111
|
self.inner_storage = inner_storage
|
|
111
112
|
self.inner_storage_path = inner_storage_path
|
|
112
|
-
|
|
113
|
+
self._cache_filename = cache_filename
|
|
114
|
+
|
|
115
|
+
@property
def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
    """Path recorded for this storage, or ``None`` when the inner storage has no cache file."""
    inner_cache = self.inner_storage.cache_filename
    if inner_cache is None:
        return None
    return self._cache_filename

@property
def is_mutable(self) -> bool:
    """Whether this storage accepts writes; delegated to the inner storage."""
    inner = self.inner_storage
    return inner.is_mutable

@property
def is_multiprocessing_safe(self) -> bool:
    """Whether this storage may be used across processes; delegated to the inner storage."""
    inner = self.inner_storage
    return inner.is_multiprocessing_safe
|
|
126
|
+
|
|
113
127
|
@property
|
|
114
128
|
def capacity(self) -> Optional[int]:
|
|
115
129
|
return self.inner_storage.capacity
|
|
@@ -2,7 +2,6 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen
|
|
|
2
2
|
|
|
3
3
|
from unienv_interface.space import Space, BoxSpace, DictSpace, TextSpace, BinarySpace
|
|
4
4
|
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
5
|
-
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
6
5
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
7
6
|
from unienv_interface.backends.numpy import NumpyComputeBackend, NumpyArrayType, NumpyDeviceType, NumpyDtypeType, NumpyRNGType
|
|
8
7
|
from unienv_interface.utils.symbol_util import *
|
|
@@ -498,7 +497,7 @@ class HDF5Storage(SpaceStorage[
|
|
|
498
497
|
capacity=capacity,
|
|
499
498
|
reduce_io=reduce_io,
|
|
500
499
|
)
|
|
501
|
-
|
|
500
|
+
|
|
502
501
|
@classmethod
|
|
503
502
|
def load_from(
|
|
504
503
|
cls,
|
|
@@ -562,6 +561,20 @@ class HDF5Storage(SpaceStorage[
|
|
|
562
561
|
assert self.capacity is None or self._len == self.capacity, \
|
|
563
562
|
f"If the storage has a fixed capacity, the length must match the capacity. Expected {self.capacity}, got {self._len}"
|
|
564
563
|
|
|
564
|
+
@property
def is_mutable(self) -> bool:
    """Whether the backing HDF5 file was opened for writing (any mode other than ``'r'``)."""
    file_mode = self.root.file.mode
    return file_mode != 'r'

@property
def is_multiprocessing_safe(self) -> bool:
    """Only read-only storages are reported as safe to use across processes."""
    return False if self.is_mutable else True

@property
def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
    """Filename of the HDF5 file when ``root`` is an ``h5py.File``; ``None`` otherwise (e.g. a group root)."""
    root = self.root
    return root.filename if isinstance(root, h5py.File) else None
|
|
577
|
+
|
|
565
578
|
def extend_length(self, length):
|
|
566
579
|
assert self.capacity is None, \
|
|
567
580
|
"Cannot extend length of a storage with fixed capacity"
|
|
@@ -644,4 +657,30 @@ class HDF5Storage(SpaceStorage[
|
|
|
644
657
|
def close(self):
# Close self.root when it is an h5py.File (this storage owns the file handle), then set self.root to None.
# NOTE(review): indentation is not visible in this rendering; confirm whether `self.root = None`
# runs only inside the isinstance branch or unconditionally (e.g. also when root is a group).
|

645
658
|
if isinstance(self.root, h5py.File):
|

646
659
|
self.root.close()
|

647
|
-
self.root = None
|

660
|
+
self.root = None
|
|
661
|
+
|
|
662
|
+
def __getstate__(self):
    """
    Make the storage picklable by replacing the live h5py handle with what is needed to reopen it.

    The state records the file name and open mode. When ``root`` is an object inside a file
    (e.g. a group) rather than the ``h5py.File`` itself, its path is recorded as ``full_name``.
    A closed storage (``root is None``) is pickled without any file information.
    """
    state = self.__dict__.copy()
    root = state.pop('root', None)
    # The previous check `if (self.root, h5py.File):` tested a non-empty tuple and was
    # always true. Group roots then failed on `.filename` and never recorded `full_name`.
    if isinstance(root, h5py.File):
        state['filename'] = root.filename
        state['mode'] = root.file.mode
    elif root is not None:
        state['filename'] = root.file.filename
        state['mode'] = root.file.mode
        state['full_name'] = root.name
    return state

def __setstate__(self, state):
    """
    Restore a storage pickled by ``__getstate__``, reopening the HDF5 file when one was recorded.

    If no file information is present (the storage was closed before pickling), ``root`` is
    restored as ``None``.
    """
    state = dict(state)
    filename = state.pop('filename', None)
    mode = state.pop('mode', None)
    full_name = state.pop('full_name', None)
    # The previous condition `'filename' and 'mode' in state` only tested 'mode'; missing
    # keys also raised KeyError on the unconditional deletes.
    root = None
    if filename is not None and mode is not None:
        root = h5py.File(filename, mode=mode)
        if full_name is not None:
            root = root[full_name]
    self.__dict__.update(state)
    self.root = root
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
3
|
from unienv_interface.space import Space, BoxSpace
|
|
4
|
-
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
5
4
|
from unienv_interface.backends import ComputeBackend
|
|
6
5
|
from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
|
|
7
6
|
from unienv_data.base import SpaceStorage
|
|
@@ -24,6 +23,7 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
24
23
|
is_memmap : bool = False,
|
|
25
24
|
cache_path : Optional[str] = None,
|
|
26
25
|
memmap_existok : bool = True,
|
|
26
|
+
multiprocessing : bool = False,
|
|
27
27
|
) -> "PytorchTensorStorage":
|
|
28
28
|
assert single_instance_space.backend is PyTorchComputeBackend, \
|
|
29
29
|
f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
|
|
@@ -54,8 +54,10 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
54
54
|
dtype=single_instance_space.dtype,
|
|
55
55
|
device=single_instance_space.device
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
|
|
57
|
+
if multiprocessing:
|
|
58
|
+
data = data.share_memory_()
|
|
59
|
+
|
|
60
|
+
return PytorchTensorStorage(single_instance_space, data, mutable=True)
|
|
59
61
|
|
|
60
62
|
@classmethod
|
|
61
63
|
def load_from(
|
|
@@ -66,11 +68,15 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
66
68
|
is_memmap : bool = False,
|
|
67
69
|
capacity : Optional[int] = None,
|
|
68
70
|
read_only : bool = True,
|
|
71
|
+
multiprocessing : bool = False,
|
|
69
72
|
) -> "PytorchTensorStorage":
|
|
70
73
|
assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
|
|
71
74
|
assert capacity is not None, "Capacity must be specified when creating a new tensor"
|
|
72
75
|
assert os.path.exists(path), "File does not exist"
|
|
73
76
|
|
|
77
|
+
if is_memmap and not read_only:
|
|
78
|
+
assert os.access(path, os.W_OK), "File is not writable, cannot open in read-write mode"
|
|
79
|
+
|
|
74
80
|
target_shape = (capacity, *single_instance_space.shape)
|
|
75
81
|
target_data = MemoryMappedTensor.from_filename(
|
|
76
82
|
path,
|
|
@@ -88,11 +94,14 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
88
94
|
dtype=single_instance_space.dtype,
|
|
89
95
|
device=single_instance_space.device
|
|
90
96
|
)
|
|
91
|
-
|
|
97
|
+
if multiprocessing:
|
|
98
|
+
data = data.share_memory_()
|
|
99
|
+
data = data.copy_(target_data)
|
|
92
100
|
|
|
93
101
|
return PytorchTensorStorage(
|
|
94
102
|
single_instance_space,
|
|
95
|
-
data
|
|
103
|
+
data,
|
|
104
|
+
mutable=not read_only
|
|
96
105
|
)
|
|
97
106
|
|
|
98
107
|
# ========== Instance Implementations ==========
|
|
@@ -104,6 +113,7 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
104
113
|
self,
|
|
105
114
|
single_instance_space : BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
|
|
106
115
|
data : Union[torch.Tensor, MemoryMappedTensor],
|
|
116
|
+
mutable : bool = True,
|
|
107
117
|
):
|
|
108
118
|
assert single_instance_space.shape == data.shape[1:], \
|
|
109
119
|
f"Single instance space shape {single_instance_space.shape} does not match data shape {data.shape[1:]}"
|
|
@@ -111,6 +121,7 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
111
121
|
single_instance_space
|
|
112
122
|
)
|
|
113
123
|
self.data = data
|
|
124
|
+
self._mutable = mutable
|
|
114
125
|
|
|
115
126
|
@property
|
|
116
127
|
def device(self) -> Optional[PyTorchDeviceType]:
|
|
@@ -122,6 +133,14 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
122
133
|
return self.data.filename
|
|
123
134
|
return None
|
|
124
135
|
|
|
136
|
+
@property
|
|
137
|
+
def is_mutable(self) -> bool:
|
|
138
|
+
return self._mutable
|
|
139
|
+
|
|
140
|
+
@property
|
|
141
|
+
def is_multiprocessing_safe(self) -> bool:
|
|
142
|
+
return self.data.is_shared()
|
|
143
|
+
|
|
125
144
|
@property
|
|
126
145
|
def capacity(self) -> int:
|
|
127
146
|
return self.data.shape[0]
|
|
@@ -134,9 +153,11 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
134
153
|
return self.data[index]
|
|
135
154
|
|
|
136
155
|
def set(self, index : Union[int, slice, torch.Tensor], value : torch.Tensor) -> None:
|
|
156
|
+
assert self.is_mutable, "Storage is not mutable"
|
|
137
157
|
self.data[index] = value
|
|
138
158
|
|
|
139
159
|
def clear(self) -> None:
|
|
160
|
+
assert self.is_mutable, "Storage is not mutable"
|
|
140
161
|
pass
|
|
141
162
|
|
|
142
163
|
def dumps(self, path: Union[str, os.PathLike]) -> None:
|
|
@@ -3,9 +3,7 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen
|
|
|
3
3
|
|
|
4
4
|
from unienv_interface.space import Space, BoxSpace
|
|
5
5
|
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
-
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
7
6
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
|
-
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
9
7
|
from unienv_interface.utils.symbol_util import *
|
|
10
8
|
from unienv_interface.transformations import DataTransformation
|
|
11
9
|
|
|
@@ -21,7 +21,6 @@ class FuncWorldNode(ABC, Generic[
|
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
23
|
name : str
|
|
24
|
-
world : FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
25
24
|
control_timestep : Optional[float] = None
|
|
26
25
|
context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
27
26
|
observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
@@ -29,6 +28,7 @@ class FuncWorldNode(ABC, Generic[
|
|
|
29
28
|
has_reward : bool = False
|
|
30
29
|
has_termination_signal : bool = False
|
|
31
30
|
has_truncation_signal : bool = False
|
|
31
|
+
world : Optional[FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
32
32
|
|
|
33
33
|
@property
|
|
34
34
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
@@ -8,7 +8,7 @@ from .world import World
|
|
|
8
8
|
|
|
9
9
|
class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceType, BDtypeType, BRNGType]):
|
|
10
10
|
"""
|
|
11
|
-
Each `WorldNode` in the simulated / real world will manage some aspect of the environment.
|
|
11
|
+
Each `WorldNode` in the simulated / real world will manage some aspect of the environment. This can include sensors, robots, or other entities that interact with the world.
|
|
12
12
|
How the methods in this class will be called once environment resets:
|
|
13
13
|
`World.reset(...)` -> `WorldNode.reset(...)` -> `WorldNode.after_reset(...)` -> `WorldNode.get_observation(...)` -> World can start stepping normally
|
|
14
14
|
How the methods in this class will be called during a environment step:
|
|
@@ -16,7 +16,6 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
name : str
|
|
19
|
-
world : World[BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
20
19
|
control_timestep : Optional[float] = None
|
|
21
20
|
context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
22
21
|
observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
@@ -24,6 +23,7 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
|
|
|
24
23
|
has_reward : bool = False
|
|
25
24
|
has_termination_signal : bool = False
|
|
26
25
|
has_truncation_signal : bool = False
|
|
26
|
+
world : Optional[World[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
27
27
|
|
|
28
28
|
@property
|
|
29
29
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|