unienv 0.0.1b2__py3-none-any.whl → 0.0.1b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b4.dist-info/METADATA +74 -0
- unienv-0.0.1b4.dist-info/RECORD +93 -0
- {unienv-0.0.1b2.dist-info → unienv-0.0.1b4.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b4.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +111 -51
- unienv_data/base/storage.py +12 -3
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/dict_storage.py +341 -0
- unienv_data/storages/{common.py → flattened.py} +24 -5
- unienv_data/storages/hdf5.py +333 -23
- unienv_data/storages/pytorch.py +27 -5
- unienv_data/storages/transformation.py +189 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/world/funcnode.py +1 -1
- unienv_interface/world/node.py +2 -2
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b2.dist-info/METADATA +0 -73
- unienv-0.0.1b2.dist-info/RECORD +0 -85
- unienv-0.0.1b2.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b2.dist-info → unienv-0.0.1b4.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: unienv
|
|
3
|
+
Version: 0.0.1b4
|
|
4
|
+
Summary: Unified robot environment framework supporting multiple tensor and simulation backends
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
|
|
7
|
+
Project-URL: Documentation, https://github.com/UniEnvOrg/UniEnv
|
|
8
|
+
Project-URL: Repository, https://github.com/UniEnvOrg/UniEnv
|
|
9
|
+
Project-URL: Issues, https://github.com/UniEnvOrg/UniEnv/issues
|
|
10
|
+
Project-URL: Changelog, https://github.com/UniEnvOrg/UniEnv/blob/main/CHANGELOG.md
|
|
11
|
+
Requires-Python: >=3.10
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: xbarray>=0.0.1a8
|
|
16
|
+
Requires-Dist: pillow
|
|
17
|
+
Requires-Dist: cloudpickle
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: pytest; extra == "dev"
|
|
20
|
+
Provides-Extra: gymnasium
|
|
21
|
+
Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
|
|
22
|
+
Provides-Extra: video
|
|
23
|
+
Requires-Dist: moviepy>=2.1; extra == "video"
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# UniEnv
|
|
27
|
+
|
|
28
|
+
A framework unifying robot environments and data APIs. UniEnv provides a universal interface for robot actors, sensors, environments, and data.
|
|
29
|
+
|
|
30
|
+
## Cross-Backend Tensor Library Support
|
|
31
|
+
|
|
32
|
+
UniEnv supports multiple tensor backends with zero-copy translation layers through the DLPack protocol, and allows you to use the same abstract compute backend interface to write custom data transformation layers, environment wrappers and other utilities. This is powered by the [XBArray](https://github.com/UniEnvOrg/XBArray) package.
|
|
33
|
+
|
|
34
|
+
## Universal Robot Environment Interface
|
|
35
|
+
|
|
36
|
+
UniEnv supports diverse simulation environments and real robots, built on top of the abstract environment / world interface. This allows you to reuse code across different simulated and real robots.
|
|
37
|
+
|
|
38
|
+
## Universal Robot Data Interface
|
|
39
|
+
|
|
40
|
+
UniEnv provides a universal data interface for accessing robot data through the abstract `BatchBase` interface. We also provide a `ReplayBuffer` utility for saving data from various environments, with support for diverse storage formats including `hdf5`, memory-mapped PyTorch tensors, and others.
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
Install the package with pip:
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
pip install unienv
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
You can install optional dependencies such as `gymnasium` (for Gymnasium-compatible environments), `dev` (testing tools), or `video` (video recording) by running:
|
|
51
|
+
|
|
52
|
+
```bash
|
|
53
|
+
pip install "unienv[gymnasium,video]"
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Cite
|
|
57
|
+
|
|
58
|
+
If you use UniEnv in your research, please cite it as follows:
|
|
59
|
+
|
|
60
|
+
```bibtex
|
|
61
|
+
@software{cao_unienv,
|
|
62
|
+
author = {Cao, Yunhao AND Fang, Kuan},
|
|
63
|
+
title = {{UniEnv: Unifying Robot Environments and Data APIs}},
|
|
64
|
+
year = {2025},
|
|
65
|
+
month = oct,
|
|
66
|
+
url = {https://github.com/UniEnvOrg/UniEnv},
|
|
67
|
+
license = {MIT}
|
|
68
|
+
}
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## Acknowledgements
|
|
72
|
+
|
|
73
|
+
This project is inspired by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) and its predecessor [OpenAI Gym](https://github.com/openai/gym).
|
|
74
|
+
This library would not be possible without the great work of the Data APIs Consortium on the [Array API Standard](https://data-apis.org/array-api/latest/). The zero-copy translation layers are powered by the [DLPack](https://github.com/dmlc/dlpack) project.
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
unienv-0.0.1b4.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
|
|
2
|
+
unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
|
|
3
|
+
unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
|
|
4
|
+
unienv_data/base/common.py,sha256=EYOzuYmvsCy1uJftsw6cXeycPIr8P7GWZ3_q4wgoNeo,12879
|
|
5
|
+
unienv_data/base/storage.py,sha256=s99PYEZGa76kf-Enx57kOyVkwjb-tpU-vTHcGc5Dcew,5415
|
|
6
|
+
unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
|
|
7
|
+
unienv_data/batches/backend_compat.py,sha256=7Juf7nU2jYHohRzNzmGfqMMpJtFM-3oTzzLu6EbC77E,8168
|
|
8
|
+
unienv_data/batches/combined_batch.py,sha256=aua1H86sa_qWrCtAAp5JIZMGtFiiKFPFkU0y5JoyElM,15325
|
|
9
|
+
unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
|
|
10
|
+
unienv_data/batches/slicestack_batch.py,sha256=J2EhARcPA-zz6EBnV7OLzm4yyvnZ06vrdUoPD5RkJ-o,16672
|
|
11
|
+
unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
|
|
12
|
+
unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
|
|
13
|
+
unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
|
|
14
|
+
unienv_data/replay_buffer/replay_buffer.py,sha256=nhbC-7aHGIYhcCdmaaDdhB2U9ODAZrbKMq8dP8ffOv0,10344
|
|
15
|
+
unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=fxV6FIqAHhN8opYs2WjAJMPqNRWD3iIku-4WlaydyG4,20737
|
|
16
|
+
unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
|
|
17
|
+
unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
|
|
18
|
+
unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
|
|
19
|
+
unienv_data/storages/dict_storage.py,sha256=SqCGcGT9Y4l0thdmx23XSxRMzIEIuldA6m8Cd9HrpnA,12588
|
|
20
|
+
unienv_data/storages/flattened.py,sha256=Fu01TjrzvmyNhXEGtC4FiBTb7cqXDtVkErc1QNwLvcI,6704
|
|
21
|
+
unienv_data/storages/hdf5.py,sha256=F_mkrmX6SGT2HamJAyYopBmj_Nf5NzJiyvVN9irtiiM,26260
|
|
22
|
+
unienv_data/storages/pytorch.py,sha256=ftO8cND7PFV0J1B1o2YOWqj4U_pyWsJvWv9lC9A7LJg,6953
|
|
23
|
+
unienv_data/storages/transformation.py,sha256=9BIwrvdruiTRduqC03e5UbSjBT1jLSxLCkNfrsVDP7I,7577
|
|
24
|
+
unienv_data/transformations/image_compress.py,sha256=dINrvmpTWy3sbqruHk0kPZG2XNyJI90ERgErXV7GamE,9131
|
|
25
|
+
unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
|
|
26
|
+
unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQwXscyM7FM,198
|
|
27
|
+
unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
|
|
28
|
+
unienv_interface/backends/jax.py,sha256=26Wu5OQ4EEjolyZoELhlWMPNSZ7LsVoKEGpd09L80Ck,533
|
|
29
|
+
unienv_interface/backends/numpy.py,sha256=6dMB2Vq7mrWukobyyGvuccluZUgjVkxr7x0hrUc_pe8,542
|
|
30
|
+
unienv_interface/backends/pytorch.py,sha256=BddHmZAngsaedFlvj1mKdXpNe6AWvNwEXq_eTEUoFWA,592
|
|
31
|
+
unienv_interface/backends/serialization.py,sha256=0TZlpfbP1DRB4FkM8ysDVQmn6RlYtIPisyeHjvHr7bE,2289
|
|
32
|
+
unienv_interface/env_base/__init__.py,sha256=JuaVgWlg313LZpflt4LSErY94nUrfvUp0LbIPUle0MA,226
|
|
33
|
+
unienv_interface/env_base/env.py,sha256=PV-AEmKwSjnFDjZFYtBW-At9w4fpm_I5C7GhfxPPrs4,4833
|
|
34
|
+
unienv_interface/env_base/funcenv.py,sha256=Qwm9BP4NrsVHOr7X0l3-mbsn5IhaO3-ZVW48dLg08-k,10609
|
|
35
|
+
unienv_interface/env_base/funcenv_wrapper.py,sha256=chw1iJ1RhAFMv4JAk67cttJvI9agdSm1QxNxZq0-hgM,7757
|
|
36
|
+
unienv_interface/env_base/vec_env.py,sha256=bcv6NdOxt0Xp1fRMXqzFtmVw6LQ-pDj_Jvj-qaW6otQ,16116
|
|
37
|
+
unienv_interface/env_base/wrapper.py,sha256=7hf4Rr2wouS0igPoahhvb2tzYY3bCaWL0NlgwpYZwQs,9734
|
|
38
|
+
unienv_interface/func_wrapper/__init__.py,sha256=6BPF8O25WkIBpODVTwnOE9HGSm3KRKX6iPwFGWESlxA,123
|
|
39
|
+
unienv_interface/func_wrapper/frame_stack.py,sha256=52CqAHDqwgHtOwMwxzB3Syup9kA19zdlvXCH4mI7MNU,6819
|
|
40
|
+
unienv_interface/func_wrapper/transformation.py,sha256=7mdzcpjLjqtpbtXoqbkGtTMPQxoMmMsqzDWHcZLbrhs,5939
|
|
41
|
+
unienv_interface/space/__init__.py,sha256=6-wLoD9mKDAfz7IuQs_Rn9DMDfDwTZ0tEhQ924libpg,99
|
|
42
|
+
unienv_interface/space/space.py,sha256=mFlCcDvMgEPTXlwo_iwBlm6Eg4Bn2rrecgsfIVstdq0,4067
|
|
43
|
+
unienv_interface/space/space_utils/__init__.py,sha256=GAsPoZC8YNabx3Gw5m2o4zsnG8zmA3mcuM9_lNKhiGo,121
|
|
44
|
+
unienv_interface/space/space_utils/batch_utils.py,sha256=qXK7kERPXKGIYozz7lpjzVz56S9GkH6ZASfIRzCYXHY,36993
|
|
45
|
+
unienv_interface/space/space_utils/construct_utils.py,sha256=Y4RpV9obY8XQ85O3r_NC1HrBk-Nm941ffRNXNL7nHgA,8323
|
|
46
|
+
unienv_interface/space/space_utils/flatten_utils.py,sha256=kkHkjrsk43NDbg3Q5VAhVoIXStuRayYFO-7knsDzx4A,12289
|
|
47
|
+
unienv_interface/space/space_utils/gym_utils.py,sha256=nH8EKruOKCXNrIMPUd9F4XGKCfFkhxsTmx4I1BeSgn0,15079
|
|
48
|
+
unienv_interface/space/space_utils/serialization_utils.py,sha256=LWYSFN7E6tEFe8ULWm42LkFUxP_0dfTGkCcx0yl4Y8s,9530
|
|
49
|
+
unienv_interface/space/spaces/__init__.py,sha256=Jap768TlwHFDDiTzHZ0qaHEFEVC1cKA2QzLlSZVQnjI,535
|
|
50
|
+
unienv_interface/space/spaces/batched.py,sha256=RA8aLUSS14zBSCTm_ud18TTa-ntbIZ074xwJ0xls1Jk,3691
|
|
51
|
+
unienv_interface/space/spaces/binary.py,sha256=0iQUbO37dhkznVpjhsJdwlD-KdWgCEx2H7KrybuZ_PM,3570
|
|
52
|
+
unienv_interface/space/spaces/box.py,sha256=NCmileEZOKz-L3WNzZ-uwydrRFsIMdNZBwTn1vWgeI0,13316
|
|
53
|
+
unienv_interface/space/spaces/dict.py,sha256=G5_iYC1Bj5DqeJ7aFlq6eRJbnpATbIRIyRu1jF_UUvk,7022
|
|
54
|
+
unienv_interface/space/spaces/dynamic_box.py,sha256=HvMNgzfYwIVc5VVgCtq-8lQbNI1V1dZMI-w60AwYss4,19591
|
|
55
|
+
unienv_interface/space/spaces/graph.py,sha256=KocRFLtYP5VWYpwbP6HybXH5R4jTIYJdNePKb6vhnYE,15163
|
|
56
|
+
unienv_interface/space/spaces/text.py,sha256=ePGGJdiD3q-BAX6IHLO7HMe0OH4VrzF043K02eb0zXI,4443
|
|
57
|
+
unienv_interface/space/spaces/tuple.py,sha256=rgZQz3EB3CLbIk9UlHBQbM6w9gssbA1izm-Qq-_Chqs,4267
|
|
58
|
+
unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
|
|
59
|
+
unienv_interface/transformations/__init__.py,sha256=g19uGnDHMywvDAXRaqFgoWAF1vCPrbJENEpaEgtIrOw,353
|
|
60
|
+
unienv_interface/transformations/batch_and_unbatch.py,sha256=ELCnNtwmgA5wpTBJZasfNSHmtf4vzydzLPmO6IGbT9o,1164
|
|
61
|
+
unienv_interface/transformations/chained_transform.py,sha256=TDnUvxUKK6bXGc_sfr6ZCvvVWw7P5KX2sA9i7i2lx14,2075
|
|
62
|
+
unienv_interface/transformations/dict_transform.py,sha256=ynrJrloVUix2I27Ir1mL86crT0vY5DvpiBAVxPBJup4,5357
|
|
63
|
+
unienv_interface/transformations/filter_dict.py,sha256=DzR-hgHoHJObTipxwB2UrKVlTxbfIrJohaOgqdAICLY,5871
|
|
64
|
+
unienv_interface/transformations/rescale.py,sha256=fM5ukWUvNvPeDO48_PRU0KyyvGhBIDxaN9XZyQ1VaQQ,4364
|
|
65
|
+
unienv_interface/transformations/transformation.py,sha256=u4_9H1tvophhgG0p0F3xfkMMsRuaKY2TQmVeGoeQsJ0,1652
|
|
66
|
+
unienv_interface/utils/control_util.py,sha256=lY_1EknglY3cNekWX9rYWt0ZUglaPMtIt4M5D9y0WfE,2351
|
|
67
|
+
unienv_interface/utils/data_queue.py,sha256=UZiuQDOn39DB9Heu6xinrwuzAL3X8jHlDkFoSC5Phtc,5707
|
|
68
|
+
unienv_interface/utils/seed_util.py,sha256=Up3nBXj7L8w-S9W5Q1U2d9accMhMf0TmHPaN6JXDVWs,677
|
|
69
|
+
unienv_interface/utils/stateclass.py,sha256=xjzicPGX1UuI7q3ZAxhBCCoouKfNtLywUzQtLaT0yS4,1390
|
|
70
|
+
unienv_interface/utils/symbol_util.py,sha256=NAERK-D_2MaTg2eYW-L75tbzPQN5YJIiKtM9zuQ89Sw,383
|
|
71
|
+
unienv_interface/utils/vec_util.py,sha256=EIK680ReCl_rr-qKP8co5hwz8Dx-gks8SHf-CLOZSOA,373
|
|
72
|
+
unienv_interface/world/__init__.py,sha256=aGuYTz8XFzW32RGkdi2b2LJ1sa0kgFrQyOR3JXDEwLQ,230
|
|
73
|
+
unienv_interface/world/combined_funcnode.py,sha256=O9qWxhtMJkDVtWuGyaeEj3nKMgIyRAPqF9-5LU6yna8,10853
|
|
74
|
+
unienv_interface/world/combined_node.py,sha256=tG7I9uWVxDDN6M6KeC1D14MV7YUnXYMUK9L9KXHnViA,9090
|
|
75
|
+
unienv_interface/world/funcnode.py,sha256=WvTNisOwPTwWlxC5NwQRxi-gh6MxLohh7ulctj-2YXY,7846
|
|
76
|
+
unienv_interface/world/funcworld.py,sha256=GLp8nS0Ym1gaj7FWvD5FPkQElCgZMbpyuLsIMU0w-sw,2020
|
|
77
|
+
unienv_interface/world/node.py,sha256=EAvHnx0u7IudmWQDbAUIRVEqB4kh2Xsm1aXdS3CeloY,6095
|
|
78
|
+
unienv_interface/world/world.py,sha256=Kl7wbNbs2YR3CjFrCLFhDB3DQUAWM6LjBwSADQtBTII,5740
|
|
79
|
+
unienv_interface/wrapper/__init__.py,sha256=ZNqr-WjVRqgvIxkLkeABxpYZ6tRgJNZOzmluDeJ6W_w,614
|
|
80
|
+
unienv_interface/wrapper/action_rescale.py,sha256=rTJlEHvWSuwGVX83cjfLWvszBk7B2iExX_K37vH8Wic,1231
|
|
81
|
+
unienv_interface/wrapper/backend_compat.py,sha256=T6hosgu2hrZvg3xtnyELmR6Exlz-ztqdj9vdyiz7bhI,7081
|
|
82
|
+
unienv_interface/wrapper/batch_and_unbatch.py,sha256=HpmnppgOKmshNlfmJYkGQYtEU7_U7q3mEdY5n4UaqEY,3457
|
|
83
|
+
unienv_interface/wrapper/control_frequency_limit.py,sha256=B0E2aUbaUr2p2yIN6wT3q4rAbPYsVroioqma2qKMoC0,2322
|
|
84
|
+
unienv_interface/wrapper/flatten.py,sha256=NWA5xne5j_L34oq_wT85wGvp6iHwdCSeGsk1DMugvRw,5837
|
|
85
|
+
unienv_interface/wrapper/frame_stack.py,sha256=lZZh_T_AmxsRWeYSLsTU321lVgIt12MX1eWl_yRNlWg,6002
|
|
86
|
+
unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKilWaQQyBkw,9789
|
|
87
|
+
unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
|
|
88
|
+
unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
|
|
89
|
+
unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
|
|
90
|
+
unienv-0.0.1b4.dist-info/METADATA,sha256=R_70XnKo1K6ObRxMmSlW1W_lxfD_rGR6txa3wBHGPOM,3033
|
|
91
|
+
unienv-0.0.1b4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
92
|
+
unienv-0.0.1b4.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
|
|
93
|
+
unienv-0.0.1b4.dist-info/RECORD,,
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
MIT License
|
|
2
2
|
|
|
3
|
-
Copyright (c)
|
|
4
|
-
Copyright (c) 2022 Farama Foundation
|
|
5
|
-
Copyright (c) 2024 Yunhao Cao
|
|
3
|
+
Copyright (c) 2025 Yunhao Cao and UniEnv Contributors
|
|
6
4
|
|
|
7
5
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
6
|
of this software and associated documentation files (the "Software"), to deal
|
unienv_data/base/__init__.py
CHANGED
unienv_data/base/common.py
CHANGED
|
@@ -9,26 +9,44 @@ import dataclasses
|
|
|
9
9
|
|
|
10
10
|
from unienv_interface.space.space_utils import batch_utils as space_batch_utils, flatten_utils as space_flatten_utils
|
|
11
11
|
|
|
12
|
+
__all__ = [
|
|
13
|
+
"BatchT",
|
|
14
|
+
"BatchBase",
|
|
15
|
+
"BatchSampler",
|
|
16
|
+
"IndexableType",
|
|
17
|
+
"convert_index_to_backendarray",
|
|
18
|
+
]
|
|
19
|
+
|
|
12
20
|
IndexableType = Union[int, slice, EllipsisType]
|
|
13
21
|
|
|
22
|
+
def convert_index_to_backendarray(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    index : IndexableType,
    length : int,
    device : Optional[BDeviceType] = None,
) -> BArrayType:
    """Materialize a Python-level index as a 1-D integer array on ``backend``.

    - ``Ellipsis``: every position ``0 .. length - 1``.
    - ``slice``: the positions it selects in a sequence of ``length`` items,
      resolved through ``slice.indices(length)``.
    - ``int``: a one-element array holding the index as given (negative
      values are not normalized).

    All results use ``backend.default_integer_dtype`` and are placed on ``device``.

    Raises:
        ValueError: If ``index`` is none of the supported kinds.
    """
    if index is Ellipsis:
        return backend.arange(length, dtype=backend.default_integer_dtype, device=device)
    if isinstance(index, slice):
        start, stop, step = index.indices(length)
        return backend.arange(start, stop, step, dtype=backend.default_integer_dtype, device=device)
    if isinstance(index, int):
        return backend.asarray([index], dtype=backend.default_integer_dtype, device=device)
    raise ValueError("Index must be an integer, slice, or Ellipsis.")
|
|
36
|
+
|
|
14
37
|
BatchT = TypeVar('BatchT')
|
|
15
38
|
class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
|
|
16
|
-
backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
17
|
-
device: Optional[BDeviceType] = None
|
|
18
|
-
|
|
19
39
|
# If the batch is mutable, then the data can be changed (extend_*, set_*, remove_*, etc.)
|
|
20
40
|
is_mutable: bool = True
|
|
21
41
|
|
|
22
42
|
def __init__(
|
|
23
43
|
self,
|
|
24
|
-
single_space : Space[
|
|
44
|
+
single_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
|
|
25
45
|
single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
|
|
26
46
|
):
|
|
27
47
|
self.single_space = single_space
|
|
28
48
|
self.single_metadata_space = single_metadata_space
|
|
29
|
-
self._batched_space : Space[
|
|
30
|
-
BatchT, Any, BDeviceType, BDtypeType, BRNGType
|
|
31
|
-
] = space_batch_utils.batch_space(single_space, 1)
|
|
49
|
+
self._batched_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType] = space_batch_utils.batch_space(single_space, 1)
|
|
32
50
|
if single_metadata_space is not None:
|
|
33
51
|
self._batched_metadata_space : DictSpace[
|
|
34
52
|
BDeviceType, BDtypeType, BRNGType
|
|
@@ -36,24 +54,43 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
|
|
|
36
54
|
else:
|
|
37
55
|
self._batched_metadata_space = None
|
|
38
56
|
|
|
57
|
+
@property
|
|
58
|
+
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
59
|
+
return self.single_space.backend
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def device(self) -> Optional[BDeviceType]:
|
|
63
|
+
return self.single_space.device
|
|
64
|
+
|
|
39
65
|
@abc.abstractmethod
|
|
40
66
|
def __len__(self) -> int:
|
|
41
67
|
raise NotImplementedError
|
|
42
68
|
|
|
43
69
|
def get_flattened_at(self, idx : Union[IndexableType, BArrayType]) -> BArrayType:
|
|
44
|
-
|
|
70
|
+
unflattened_data = self.get_at(idx)
|
|
71
|
+
if isinstance(idx, int):
|
|
72
|
+
return space_flatten_utils.flatten_data(self.single_space, unflattened_data)
|
|
73
|
+
else:
|
|
74
|
+
return space_flatten_utils.flatten_data(self._batched_space, unflattened_data, start_dim=1)
|
|
45
75
|
|
|
46
|
-
@abc.abstractmethod
|
|
47
76
|
def get_flattened_at_with_metadata(
|
|
48
77
|
self, idx : Union[IndexableType, BArrayType]
|
|
49
78
|
) -> Tuple[BArrayType, Optional[Dict[str, Any]]]:
|
|
50
|
-
|
|
79
|
+
unflattened_data, metadata = self.get_at_with_metadata(idx)
|
|
80
|
+
if isinstance(idx, int):
|
|
81
|
+
return space_flatten_utils.flatten_data(self.single_space, unflattened_data), metadata
|
|
82
|
+
else:
|
|
83
|
+
return space_flatten_utils.flatten_data(self._batched_space, unflattened_data, start_dim=1), metadata
|
|
51
84
|
|
|
52
85
|
def set_flattened_at(self, idx : Union[IndexableType, BArrayType], value : BArrayType) -> None:
|
|
53
86
|
raise NotImplementedError
|
|
54
87
|
|
|
88
|
+
def append_flattened(self, value : BArrayType) -> None:
    """Append one flattened element.

    The element is turned into a batch of one by inserting a leading axis of
    size 1 (``value[None]``) and then handed to ``extend_flattened``.
    """
    batch_of_one = value[None]
    return self.extend_flattened(batch_of_one)
|
|
90
|
+
|
|
55
91
|
def extend_flattened(self, value : BArrayType) -> None:
|
|
56
|
-
|
|
92
|
+
unflat_data = space_flatten_utils.unflatten_data(self._batched_space, value, start_dim=1)
|
|
93
|
+
self.extend(unflat_data)
|
|
57
94
|
|
|
58
95
|
def get_at(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
|
|
59
96
|
flattened_data = self.get_flattened_at(idx)
|
|
@@ -90,55 +127,81 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
|
|
|
90
127
|
def __delitem__(self, idx : Union[IndexableType, BArrayType]) -> None:
|
|
91
128
|
self.remove_at(idx)
|
|
92
129
|
|
|
130
|
+
def append(self, value : BatchT) -> None:
|
|
131
|
+
batched_data = space_batch_utils.concatenate(self._batched_space, [value])
|
|
132
|
+
self.extend(batched_data)
|
|
133
|
+
|
|
93
134
|
def extend(self, value : BatchT) -> None:
|
|
94
135
|
flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
|
|
95
136
|
self.extend_flattened(flattened_data)
|
|
96
137
|
|
|
138
|
+
def extend_from(
    self,
    other : 'BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]',
    chunk_size : int = 8,
    tqdm : bool = False,
) -> None:
    """Append every element of ``other`` to this batch, ``chunk_size`` elements at a time.

    Elements are read from ``other`` in index order with ``other.get_at(slice(...))``
    and written with ``self.extend``. The last chunk may be shorter than ``chunk_size``.

    Args:
        other: Source batch to copy from.
        chunk_size: Maximum number of elements read and written per step.
        tqdm: If True, show a progress bar over the chunks. This requires the
            optional ``tqdm`` package, which is imported lazily.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    n_total = len(other)
    chunk_starts = range(0, n_total, chunk_size)
    if tqdm:
        # Aliased so the import does not rebind the boolean ``tqdm`` parameter.
        from tqdm import tqdm as _tqdm
        chunk_starts = _tqdm(chunk_starts, desc="Extending Batch")
    # Iterate the (possibly progress-wrapped) iterable so the progress bar advances.
    for start_idx in chunk_starts:
        end_idx = min(start_idx + chunk_size, n_total)
        data_chunk = other.get_at(slice(start_idx, end_idx))
        self.extend(data_chunk)
|
|
153
|
+
|
|
97
154
|
def close(self) -> None:
|
|
98
155
|
pass
|
|
99
156
|
|
|
100
|
-
def __del__(self) -> None:
|
|
101
|
-
self.close()
|
|
102
|
-
|
|
103
157
|
SamplerBatchT = TypeVar('SamplerBatchT')
|
|
104
158
|
SamplerArrayType = TypeVar('SamplerArrayType')
|
|
105
159
|
SamplerDeviceType = TypeVar('SamplerDeviceType')
|
|
106
160
|
SamplerDtypeType = TypeVar('SamplerDtypeType')
|
|
107
161
|
SamplerRNGType = TypeVar('SamplerRNGType')
|
|
108
|
-
class BatchSampler(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
backend : ComputeBackend[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
|
|
118
|
-
device : Optional[SamplerDeviceType] = None
|
|
119
|
-
|
|
162
|
+
class BatchSampler(
|
|
163
|
+
Generic[
|
|
164
|
+
SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType,
|
|
165
|
+
BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
|
|
166
|
+
],
|
|
167
|
+
BatchBase[
|
|
168
|
+
SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType
|
|
169
|
+
]
|
|
170
|
+
):
|
|
120
171
|
data : BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
121
172
|
|
|
122
173
|
rng : Optional[SamplerRNGType] = None
|
|
123
174
|
data_rng : Optional[BRNGType] = None
|
|
124
|
-
|
|
125
|
-
def get_flat_at(self, idx : SamplerArrayType) -> SamplerArrayType:
|
|
126
|
-
return self.get_flat_at_with_metadata(idx)[0]
|
|
127
|
-
|
|
128
|
-
@abc.abstractmethod
|
|
129
|
-
def get_flat_at_with_metadata(
|
|
130
|
-
self, idx : SamplerArrayType
|
|
131
|
-
) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
|
|
132
|
-
raise NotImplementedError
|
|
133
175
|
|
|
134
|
-
|
|
135
|
-
|
|
176
|
+
is_mutable : bool = False
|
|
177
|
+
|
|
178
|
+
def __init__(
|
|
179
|
+
self,
|
|
180
|
+
single_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
|
|
181
|
+
single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
|
|
182
|
+
batch_size : int = 1,
|
|
183
|
+
) -> None:
|
|
184
|
+
super().__init__(single_space=single_space, single_metadata_space=single_metadata_space)
|
|
185
|
+
self.batch_size = batch_size
|
|
186
|
+
self._batched_space : Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType] = space_batch_utils.batch_space(self.single_space, batch_size)
|
|
187
|
+
self._batched_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = space_batch_utils.batch_space(self.single_metadata_space, batch_size) if self.single_metadata_space is not None else None
|
|
188
|
+
|
|
189
|
+
def manual_seed(self, seed : int) -> None:
    """Re-create the sampler's random number generators from ``seed``.

    Only generators that are currently set (not None) are replaced:

    - ``rng`` is rebuilt with this sampler's ``backend`` on ``self.device``.
    - ``data_rng`` is rebuilt with the wrapped ``data`` batch's backend on
      ``self.data.device``. It is typed against the data batch's RNG type, so
      its backend must match the device it is placed on.

    Both generators receive the same ``seed``.
    """
    if self.rng is not None:
        self.rng = self.backend.random.random_number_generator(seed, device=self.device)
    if self.data_rng is not None:
        self.data_rng = self.data.backend.random.random_number_generator(seed, device=self.data.device)
|
|
194
|
+
|
|
195
|
+
@property
|
|
196
|
+
def sampled_space(self) -> Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]:
|
|
197
|
+
return self._batched_space
|
|
136
198
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
199
|
+
@property
|
|
200
|
+
def sampled_metadata_space(self) -> Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]]:
|
|
201
|
+
return self._batched_metadata_space
|
|
202
|
+
|
|
203
|
+
def __len__(self):
|
|
204
|
+
return len(self.data)
|
|
142
205
|
|
|
143
206
|
def sample_index(self) -> SamplerArrayType:
|
|
144
207
|
new_rng, indices = self.backend.random.random_discrete_uniform( # (B, )
|
|
@@ -156,11 +219,11 @@ class BatchSampler(abc.ABC, Generic[
|
|
|
156
219
|
|
|
157
220
|
def sample_flat(self) -> SamplerArrayType:
|
|
158
221
|
idx = self.sample_index()
|
|
159
|
-
return self.
|
|
222
|
+
return self.get_flattened_at(idx)
|
|
160
223
|
|
|
161
224
|
def sample_flat_with_metadata(self) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
|
|
162
225
|
idx = self.sample_index()
|
|
163
|
-
return self.
|
|
226
|
+
return self.get_flattened_at_with_metadata(idx)
|
|
164
227
|
|
|
165
228
|
def sample(self) -> SamplerBatchT:
|
|
166
229
|
idx = self.sample_index()
|
|
@@ -205,9 +268,9 @@ class BatchSampler(abc.ABC, Generic[
|
|
|
205
268
|
n_batches = len(self.data) // self.batch_size
|
|
206
269
|
num_left = len(self.data) % self.batch_size
|
|
207
270
|
for i in range(n_batches):
|
|
208
|
-
yield self.
|
|
271
|
+
yield self.get_flattened_at(idx[i*self.batch_size:(i+1)*self.batch_size])
|
|
209
272
|
if num_left > 0:
|
|
210
|
-
yield self.
|
|
273
|
+
yield self.get_flattened_at(idx[-num_left:])
|
|
211
274
|
|
|
212
275
|
def epoch_flat_iter_with_metadata(self) -> Iterator[Tuple[SamplerArrayType, Optional[Dict[str, Any]]]]:
|
|
213
276
|
if self.data_rng is not None:
|
|
@@ -217,12 +280,9 @@ class BatchSampler(abc.ABC, Generic[
|
|
|
217
280
|
n_batches = len(self.data) // self.batch_size
|
|
218
281
|
num_left = len(self.data) % self.batch_size
|
|
219
282
|
for i in range(n_batches):
|
|
220
|
-
yield self.
|
|
283
|
+
yield self.get_flattened_at_with_metadata(idx[i*self.batch_size:(i+1)*self.batch_size])
|
|
221
284
|
if num_left > 0:
|
|
222
|
-
yield self.
|
|
285
|
+
yield self.get_flattened_at_with_metadata(idx[-num_left:])
|
|
223
286
|
|
|
224
287
|
def close(self) -> None:
|
|
225
288
|
pass
|
|
226
|
-
|
|
227
|
-
def __del__(self) -> None:
|
|
228
|
-
self.close()
|
unienv_data/base/storage.py
CHANGED
|
@@ -31,6 +31,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
31
31
|
single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
|
|
32
32
|
*,
|
|
33
33
|
capacity : Optional[int] = None,
|
|
34
|
+
read_only : bool = True,
|
|
34
35
|
**kwargs
|
|
35
36
|
) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
36
37
|
raise NotImplementedError
|
|
@@ -56,6 +57,17 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
56
57
|
"""
|
|
57
58
|
cache_filename : Optional[Union[str, os.PathLike]] = None
|
|
58
59
|
|
|
60
|
+
"""
|
|
61
|
+
Can the storage instance be safely used in multiprocessing environments after creation?
|
|
62
|
+
If True, the storage can be used in multiprocessing environments.
|
|
63
|
+
"""
|
|
64
|
+
is_multiprocessing_safe : bool = False
|
|
65
|
+
|
|
66
|
+
"""
|
|
67
|
+
Is the storage mutable? If False, the storage is read-only.
|
|
68
|
+
"""
|
|
69
|
+
is_mutable : bool = True
|
|
70
|
+
|
|
59
71
|
@property
|
|
60
72
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
61
73
|
return self.single_instance_space.backend
|
|
@@ -127,6 +139,3 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
127
139
|
|
|
128
140
|
def close(self) -> None:
|
|
129
141
|
pass
|
|
130
|
-
|
|
131
|
-
def __del__(self) -> None:
|
|
132
|
-
self.close()
|
unienv_data/batches/__init__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from .backend_compat import ToBackendOrDeviceBatch
|
|
2
2
|
from .combined_batch import CombinedBatch
|
|
3
3
|
from .slicestack_batch import SliceStackedBatch
|
|
4
|
-
from .framestack_batch import FrameStackedBatch
|
|
4
|
+
from .framestack_batch import FrameStackedBatch
|
|
5
|
+
from .transformations import TransformedBatch
|
|
@@ -66,7 +66,7 @@ class ToBackendOrDeviceBatch(
|
|
|
66
66
|
)
|
|
67
67
|
self.batch = batch
|
|
68
68
|
self.target_backend = backend
|
|
69
|
-
self.
|
|
69
|
+
self.target_device = device
|
|
70
70
|
|
|
71
71
|
def __len__(self) -> int:
|
|
72
72
|
return len(self.batch)
|
|
@@ -79,7 +79,18 @@ class ToBackendOrDeviceBatch(
|
|
|
79
79
|
def backend(self) -> ComputeBackend[WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]:
|
|
80
80
|
return self.target_backend if self.target_backend is not None else self.batch.backend
|
|
81
81
|
|
|
82
|
+
@property
|
|
83
|
+
def device(self) -> Optional[WrapperBDeviceT]:
|
|
84
|
+
return self.target_device if self.target_device is not None else self.batch.device
|
|
85
|
+
|
|
82
86
|
def get_flattened_at(self, idx):
|
|
87
|
+
if self.target_backend.is_backendarray(idx):
|
|
88
|
+
idx = data_to(
|
|
89
|
+
idx,
|
|
90
|
+
source_backend=self.target_backend,
|
|
91
|
+
target_backend=self.batch.backend,
|
|
92
|
+
target_device=self.batch.device
|
|
93
|
+
)
|
|
83
94
|
o_data = self.batch.get_flattened_at(idx)
|
|
84
95
|
return data_to(
|
|
85
96
|
o_data,
|
|
@@ -89,6 +100,13 @@ class ToBackendOrDeviceBatch(
|
|
|
89
100
|
)
|
|
90
101
|
|
|
91
102
|
def get_flattened_at_with_metadata(self, idx):
|
|
103
|
+
if self.target_backend.is_backendarray(idx):
|
|
104
|
+
idx = data_to(
|
|
105
|
+
idx,
|
|
106
|
+
source_backend=self.target_backend,
|
|
107
|
+
target_backend=self.batch.backend,
|
|
108
|
+
target_device=self.batch.device
|
|
109
|
+
)
|
|
92
110
|
o_data, o_metadata = self.batch.get_flattened_at_with_metadata(idx)
|
|
93
111
|
return (
|
|
94
112
|
data_to(
|
|
@@ -107,6 +125,13 @@ class ToBackendOrDeviceBatch(
|
|
|
107
125
|
|
|
108
126
|
def set_flattened_at(self, idx, value):
|
|
109
127
|
assert self.is_mutable, "Batch is not mutable"
|
|
128
|
+
if self.target_backend.is_backendarray(idx):
|
|
129
|
+
idx = data_to(
|
|
130
|
+
idx,
|
|
131
|
+
source_backend=self.target_backend,
|
|
132
|
+
target_backend=self.batch.backend,
|
|
133
|
+
target_device=self.batch.device
|
|
134
|
+
)
|
|
110
135
|
value = data_to(
|
|
111
136
|
value,
|
|
112
137
|
source_backend=self.target_backend,
|
|
@@ -126,6 +151,13 @@ class ToBackendOrDeviceBatch(
|
|
|
126
151
|
self.batch.extend_flattened(value)
|
|
127
152
|
|
|
128
153
|
def get_at(self, idx):
|
|
154
|
+
if self.target_backend.is_backendarray(idx):
|
|
155
|
+
idx = data_to(
|
|
156
|
+
idx,
|
|
157
|
+
source_backend=self.target_backend,
|
|
158
|
+
target_backend=self.batch.backend,
|
|
159
|
+
target_device=self.batch.device
|
|
160
|
+
)
|
|
129
161
|
o_data = self.batch.get_at(idx)
|
|
130
162
|
return (
|
|
131
163
|
data_to(
|
|
@@ -137,6 +169,13 @@ class ToBackendOrDeviceBatch(
|
|
|
137
169
|
)
|
|
138
170
|
|
|
139
171
|
def get_at_with_metadata(self, idx):
|
|
172
|
+
if self.target_backend.is_backendarray(idx):
|
|
173
|
+
idx = data_to(
|
|
174
|
+
idx,
|
|
175
|
+
source_backend=self.target_backend,
|
|
176
|
+
target_backend=self.batch.backend,
|
|
177
|
+
target_device=self.batch.device
|
|
178
|
+
)
|
|
140
179
|
o_data, o_metadata = self.batch.get_at_with_metadata(idx)
|
|
141
180
|
return (
|
|
142
181
|
data_to(
|
|
@@ -155,6 +194,13 @@ class ToBackendOrDeviceBatch(
|
|
|
155
194
|
|
|
156
195
|
def set_at(self, idx, value):
|
|
157
196
|
assert self.is_mutable, "Batch is not mutable"
|
|
197
|
+
if self.target_backend.is_backendarray(idx):
|
|
198
|
+
idx = data_to(
|
|
199
|
+
idx,
|
|
200
|
+
source_backend=self.target_backend,
|
|
201
|
+
target_backend=self.batch.backend,
|
|
202
|
+
target_device=self.batch.device
|
|
203
|
+
)
|
|
158
204
|
o_value = data_to(
|
|
159
205
|
value,
|
|
160
206
|
source_backend=self.target_backend,
|
|
@@ -56,8 +56,6 @@ class CombinedBatch(BatchBase[
|
|
|
56
56
|
)
|
|
57
57
|
super().__init__(single_space, new_metadata_space)
|
|
58
58
|
|
|
59
|
-
self.backend = backend
|
|
60
|
-
self.device = device
|
|
61
59
|
self.is_mutable = is_mutable
|
|
62
60
|
self.batches = batches
|
|
63
61
|
self._build_index_cache()
|
|
@@ -248,7 +246,7 @@ class CombinedBatch(BatchBase[
|
|
|
248
246
|
result = result_space.create_empty()
|
|
249
247
|
for batch_index, index_into_batch, mask in batch_list:
|
|
250
248
|
result = sbu.set_at(
|
|
251
|
-
|
|
249
|
+
result_space,
|
|
252
250
|
result,
|
|
253
251
|
mask,
|
|
254
252
|
self.batches[batch_index].get_at(index_into_batch),
|
|
@@ -295,7 +293,7 @@ class CombinedBatch(BatchBase[
|
|
|
295
293
|
for batch_index, index_into_batch, mask in batch_list:
|
|
296
294
|
batch_result, metadata_result = self.batches[batch_index].get_at_with_metadata(index_into_batch)
|
|
297
295
|
result = sbu.set_at(
|
|
298
|
-
|
|
296
|
+
result_space,
|
|
299
297
|
result,
|
|
300
298
|
mask,
|
|
301
299
|
batch_result,
|
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
from typing import Optional, Any, Union
|
|
1
|
+
from typing import Optional, Any, Union, Tuple, Dict
|
|
2
2
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
3
|
|
|
4
4
|
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
5
|
-
from .common import *
|
|
6
5
|
from unienv_interface.transformations.transformation import DataTransformation, TargetDataT, SourceDataT, SourceBArrT, SourceBDeviceT, SourceBDTypeT, SourceBDRNGT
|
|
7
6
|
from unienv_interface.space import Space
|
|
8
7
|
|
|
8
|
+
from ..base.common import BatchBase, BatchT, IndexableType
|
|
9
|
+
|
|
9
10
|
class TransformedBatch(
|
|
10
11
|
BatchBase[
|
|
11
12
|
BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
@@ -63,6 +63,8 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
63
63
|
**kwargs
|
|
64
64
|
) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
65
65
|
storage_path_relative = "storage" + (storage_cls.single_file_ext or "")
|
|
66
|
+
if cache_path is not None:
|
|
67
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
66
68
|
storage = storage_cls.create(
|
|
67
69
|
single_instance_space,
|
|
68
70
|
*args,
|
|
@@ -94,6 +96,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
94
96
|
*,
|
|
95
97
|
backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
|
|
96
98
|
device: Optional[BDeviceType] = None,
|
|
99
|
+
read_only : bool = True,
|
|
97
100
|
**storage_kwargs
|
|
98
101
|
) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
99
102
|
with open(os.path.join(path, "metadata.json"), "r") as f:
|
|
@@ -114,6 +117,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
114
117
|
storage_path,
|
|
115
118
|
single_instance_space,
|
|
116
119
|
capacity=capacity,
|
|
120
|
+
read_only=read_only,
|
|
117
121
|
**storage_kwargs
|
|
118
122
|
)
|
|
119
123
|
return ReplayBuffer(storage, metadata["storage_path_relative"], count, offset, cache_path=path)
|
unienv_data/samplers/__init__.py
CHANGED