unienv 0.0.1b2__py3-none-any.whl → 0.0.1b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. unienv-0.0.1b4.dist-info/METADATA +74 -0
  2. unienv-0.0.1b4.dist-info/RECORD +93 -0
  3. {unienv-0.0.1b2.dist-info → unienv-0.0.1b4.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b4.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +111 -51
  7. unienv_data/base/storage.py +12 -3
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/dict_storage.py +341 -0
  17. unienv_data/storages/{common.py → flattened.py} +24 -5
  18. unienv_data/storages/hdf5.py +333 -23
  19. unienv_data/storages/pytorch.py +27 -5
  20. unienv_data/storages/transformation.py +189 -0
  21. unienv_data/transformations/image_compress.py +213 -0
  22. unienv_interface/backends/jax.py +4 -1
  23. unienv_interface/backends/numpy.py +4 -1
  24. unienv_interface/backends/pytorch.py +4 -1
  25. unienv_interface/env_base/__init__.py +1 -0
  26. unienv_interface/env_base/env.py +5 -0
  27. unienv_interface/env_base/funcenv.py +32 -1
  28. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  29. unienv_interface/env_base/vec_env.py +474 -0
  30. unienv_interface/func_wrapper/__init__.py +2 -1
  31. unienv_interface/func_wrapper/frame_stack.py +150 -0
  32. unienv_interface/space/space_utils/__init__.py +1 -0
  33. unienv_interface/space/space_utils/batch_utils.py +83 -0
  34. unienv_interface/space/space_utils/construct_utils.py +216 -0
  35. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  36. unienv_interface/space/spaces/__init__.py +3 -1
  37. unienv_interface/space/spaces/batched.py +90 -0
  38. unienv_interface/space/spaces/binary.py +0 -1
  39. unienv_interface/space/spaces/box.py +13 -24
  40. unienv_interface/space/spaces/text.py +1 -3
  41. unienv_interface/transformations/dict_transform.py +31 -5
  42. unienv_interface/utils/control_util.py +68 -0
  43. unienv_interface/utils/data_queue.py +184 -0
  44. unienv_interface/utils/stateclass.py +46 -0
  45. unienv_interface/utils/vec_util.py +15 -0
  46. unienv_interface/world/__init__.py +3 -1
  47. unienv_interface/world/combined_funcnode.py +336 -0
  48. unienv_interface/world/combined_node.py +232 -0
  49. unienv_interface/world/funcnode.py +1 -1
  50. unienv_interface/world/node.py +2 -2
  51. unienv_interface/wrapper/backend_compat.py +2 -2
  52. unienv_interface/wrapper/frame_stack.py +19 -114
  53. unienv_interface/wrapper/video_record.py +11 -2
  54. unienv-0.0.1b2.dist-info/METADATA +0 -73
  55. unienv-0.0.1b2.dist-info/RECORD +0 -85
  56. unienv-0.0.1b2.dist-info/top_level.txt +0 -4
  57. unienv_data/samplers/slice_sampler.py +0 -266
  58. unienv_maniskill/__init__.py +0 -1
  59. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  60. unienv_mjxplayground/__init__.py +0 -1
  61. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  62. {unienv-0.0.1b2.dist-info → unienv-0.0.1b4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.4
2
+ Name: unienv
3
+ Version: 0.0.1b4
4
+ Summary: Unified robot environment framework supporting multiple tensor and simulation backends
5
+ License-Expression: MIT
6
+ Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
7
+ Project-URL: Documentation, https://github.com/UniEnvOrg/UniEnv
8
+ Project-URL: Repository, https://github.com/UniEnvOrg/UniEnv
9
+ Project-URL: Issues, https://github.com/UniEnvOrg/UniEnv/issues
10
+ Project-URL: Changelog, https://github.com/UniEnvOrg/UniEnv/blob/main/CHANGELOG.md
11
+ Requires-Python: >=3.10
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: numpy
15
+ Requires-Dist: xbarray>=0.0.1a8
16
+ Requires-Dist: pillow
17
+ Requires-Dist: cloudpickle
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest; extra == "dev"
20
+ Provides-Extra: gymnasium
21
+ Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
22
+ Provides-Extra: video
23
+ Requires-Dist: moviepy>=2.1; extra == "video"
24
+ Dynamic: license-file
25
+
26
+ # UniEnv
27
+
28
+ Framework unifying robot environments and data APIs. UniEnv provides a universal interface for robot actors, sensors, environments, and data.
29
+
30
+ ## Cross-Backend Tensor Library Support
31
+
32
+ UniEnv supports multiple tensor backends with zero-copy translation layers through the DLPack protocol, and allows you to use the same abstract compute backend interface to write custom data transformation layers, environment wrappers and other utilities. This is powered by the [XBArray](https://github.com/UniEnvOrg/XBArray) package.
33
+
34
+ ## Universal Robot Environment Interface
35
+
36
+ UniEnv supports diverse simulation environments and real robots, built on top of the abstract environment / world interface. This allows you to reuse code across different simulated and real robots.
37
+
38
+ ## Universal Robot Data Interface
39
+
40
+ UniEnv provides a universal data interface for accessing robot data through the abstract `BatchBase` interface. We also provide a utility `ReplayBuffer` for saving data from various environments with diverse data format support, including `hdf5`, memory-mapped torch tensors, and others.
41
+
42
+ ## Installation
43
+
44
+ Install the package with pip
45
+
46
+ ```bash
47
+ pip install unienv
48
+ ```
49
+
50
+ You can install optional dependencies such as `gymnasium` (for Gymnasium-compatible environments), `dev`, or `video` by running
51
+
52
+ ```bash
53
+ pip install "unienv[gymnasium,video]"
54
+ ```
55
+
56
+ ## Cite
57
+
58
+ If you use UniEnv in your research, please cite it as follows:
59
+
60
+ ```bibtex
61
+ @software{cao_unienv,
62
+ author = {Cao, Yunhao AND Fang, Kuan},
63
+ title = {{UniEnv: Unifying Robot Environments and Data APIs}},
64
+ year = {2025},
65
+ month = oct,
66
+ url = {https://github.com/UniEnvOrg/UniEnv},
67
+ license = {MIT}
68
+ }
69
+ ```
70
+
71
+ ## Acknowledgements
72
+
73
+ The idea of this project is inspired by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) and its predecessor [OpenAI Gym](https://github.com/openai/gym).
74
+ This library would not be possible without the great work of the Data APIs Consortium on the [Array API Standard](https://data-apis.org/array-api/latest/). The zero-copy translation layers are powered by the [DLPack](https://github.com/dmlc/dlpack) project.
@@ -0,0 +1,93 @@
1
+ unienv-0.0.1b4.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
2
+ unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
3
+ unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
4
+ unienv_data/base/common.py,sha256=EYOzuYmvsCy1uJftsw6cXeycPIr8P7GWZ3_q4wgoNeo,12879
5
+ unienv_data/base/storage.py,sha256=s99PYEZGa76kf-Enx57kOyVkwjb-tpU-vTHcGc5Dcew,5415
6
+ unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
7
+ unienv_data/batches/backend_compat.py,sha256=7Juf7nU2jYHohRzNzmGfqMMpJtFM-3oTzzLu6EbC77E,8168
8
+ unienv_data/batches/combined_batch.py,sha256=aua1H86sa_qWrCtAAp5JIZMGtFiiKFPFkU0y5JoyElM,15325
9
+ unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
10
+ unienv_data/batches/slicestack_batch.py,sha256=J2EhARcPA-zz6EBnV7OLzm4yyvnZ06vrdUoPD5RkJ-o,16672
11
+ unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
12
+ unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
13
+ unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
14
+ unienv_data/replay_buffer/replay_buffer.py,sha256=nhbC-7aHGIYhcCdmaaDdhB2U9ODAZrbKMq8dP8ffOv0,10344
15
+ unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=fxV6FIqAHhN8opYs2WjAJMPqNRWD3iIku-4WlaydyG4,20737
16
+ unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
17
+ unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
18
+ unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
19
+ unienv_data/storages/dict_storage.py,sha256=SqCGcGT9Y4l0thdmx23XSxRMzIEIuldA6m8Cd9HrpnA,12588
20
+ unienv_data/storages/flattened.py,sha256=Fu01TjrzvmyNhXEGtC4FiBTb7cqXDtVkErc1QNwLvcI,6704
21
+ unienv_data/storages/hdf5.py,sha256=F_mkrmX6SGT2HamJAyYopBmj_Nf5NzJiyvVN9irtiiM,26260
22
+ unienv_data/storages/pytorch.py,sha256=ftO8cND7PFV0J1B1o2YOWqj4U_pyWsJvWv9lC9A7LJg,6953
23
+ unienv_data/storages/transformation.py,sha256=9BIwrvdruiTRduqC03e5UbSjBT1jLSxLCkNfrsVDP7I,7577
24
+ unienv_data/transformations/image_compress.py,sha256=dINrvmpTWy3sbqruHk0kPZG2XNyJI90ERgErXV7GamE,9131
25
+ unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
26
+ unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQwXscyM7FM,198
27
+ unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
28
+ unienv_interface/backends/jax.py,sha256=26Wu5OQ4EEjolyZoELhlWMPNSZ7LsVoKEGpd09L80Ck,533
29
+ unienv_interface/backends/numpy.py,sha256=6dMB2Vq7mrWukobyyGvuccluZUgjVkxr7x0hrUc_pe8,542
30
+ unienv_interface/backends/pytorch.py,sha256=BddHmZAngsaedFlvj1mKdXpNe6AWvNwEXq_eTEUoFWA,592
31
+ unienv_interface/backends/serialization.py,sha256=0TZlpfbP1DRB4FkM8ysDVQmn6RlYtIPisyeHjvHr7bE,2289
32
+ unienv_interface/env_base/__init__.py,sha256=JuaVgWlg313LZpflt4LSErY94nUrfvUp0LbIPUle0MA,226
33
+ unienv_interface/env_base/env.py,sha256=PV-AEmKwSjnFDjZFYtBW-At9w4fpm_I5C7GhfxPPrs4,4833
34
+ unienv_interface/env_base/funcenv.py,sha256=Qwm9BP4NrsVHOr7X0l3-mbsn5IhaO3-ZVW48dLg08-k,10609
35
+ unienv_interface/env_base/funcenv_wrapper.py,sha256=chw1iJ1RhAFMv4JAk67cttJvI9agdSm1QxNxZq0-hgM,7757
36
+ unienv_interface/env_base/vec_env.py,sha256=bcv6NdOxt0Xp1fRMXqzFtmVw6LQ-pDj_Jvj-qaW6otQ,16116
37
+ unienv_interface/env_base/wrapper.py,sha256=7hf4Rr2wouS0igPoahhvb2tzYY3bCaWL0NlgwpYZwQs,9734
38
+ unienv_interface/func_wrapper/__init__.py,sha256=6BPF8O25WkIBpODVTwnOE9HGSm3KRKX6iPwFGWESlxA,123
39
+ unienv_interface/func_wrapper/frame_stack.py,sha256=52CqAHDqwgHtOwMwxzB3Syup9kA19zdlvXCH4mI7MNU,6819
40
+ unienv_interface/func_wrapper/transformation.py,sha256=7mdzcpjLjqtpbtXoqbkGtTMPQxoMmMsqzDWHcZLbrhs,5939
41
+ unienv_interface/space/__init__.py,sha256=6-wLoD9mKDAfz7IuQs_Rn9DMDfDwTZ0tEhQ924libpg,99
42
+ unienv_interface/space/space.py,sha256=mFlCcDvMgEPTXlwo_iwBlm6Eg4Bn2rrecgsfIVstdq0,4067
43
+ unienv_interface/space/space_utils/__init__.py,sha256=GAsPoZC8YNabx3Gw5m2o4zsnG8zmA3mcuM9_lNKhiGo,121
44
+ unienv_interface/space/space_utils/batch_utils.py,sha256=qXK7kERPXKGIYozz7lpjzVz56S9GkH6ZASfIRzCYXHY,36993
45
+ unienv_interface/space/space_utils/construct_utils.py,sha256=Y4RpV9obY8XQ85O3r_NC1HrBk-Nm941ffRNXNL7nHgA,8323
46
+ unienv_interface/space/space_utils/flatten_utils.py,sha256=kkHkjrsk43NDbg3Q5VAhVoIXStuRayYFO-7knsDzx4A,12289
47
+ unienv_interface/space/space_utils/gym_utils.py,sha256=nH8EKruOKCXNrIMPUd9F4XGKCfFkhxsTmx4I1BeSgn0,15079
48
+ unienv_interface/space/space_utils/serialization_utils.py,sha256=LWYSFN7E6tEFe8ULWm42LkFUxP_0dfTGkCcx0yl4Y8s,9530
49
+ unienv_interface/space/spaces/__init__.py,sha256=Jap768TlwHFDDiTzHZ0qaHEFEVC1cKA2QzLlSZVQnjI,535
50
+ unienv_interface/space/spaces/batched.py,sha256=RA8aLUSS14zBSCTm_ud18TTa-ntbIZ074xwJ0xls1Jk,3691
51
+ unienv_interface/space/spaces/binary.py,sha256=0iQUbO37dhkznVpjhsJdwlD-KdWgCEx2H7KrybuZ_PM,3570
52
+ unienv_interface/space/spaces/box.py,sha256=NCmileEZOKz-L3WNzZ-uwydrRFsIMdNZBwTn1vWgeI0,13316
53
+ unienv_interface/space/spaces/dict.py,sha256=G5_iYC1Bj5DqeJ7aFlq6eRJbnpATbIRIyRu1jF_UUvk,7022
54
+ unienv_interface/space/spaces/dynamic_box.py,sha256=HvMNgzfYwIVc5VVgCtq-8lQbNI1V1dZMI-w60AwYss4,19591
55
+ unienv_interface/space/spaces/graph.py,sha256=KocRFLtYP5VWYpwbP6HybXH5R4jTIYJdNePKb6vhnYE,15163
56
+ unienv_interface/space/spaces/text.py,sha256=ePGGJdiD3q-BAX6IHLO7HMe0OH4VrzF043K02eb0zXI,4443
57
+ unienv_interface/space/spaces/tuple.py,sha256=rgZQz3EB3CLbIk9UlHBQbM6w9gssbA1izm-Qq-_Chqs,4267
58
+ unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
59
+ unienv_interface/transformations/__init__.py,sha256=g19uGnDHMywvDAXRaqFgoWAF1vCPrbJENEpaEgtIrOw,353
60
+ unienv_interface/transformations/batch_and_unbatch.py,sha256=ELCnNtwmgA5wpTBJZasfNSHmtf4vzydzLPmO6IGbT9o,1164
61
+ unienv_interface/transformations/chained_transform.py,sha256=TDnUvxUKK6bXGc_sfr6ZCvvVWw7P5KX2sA9i7i2lx14,2075
62
+ unienv_interface/transformations/dict_transform.py,sha256=ynrJrloVUix2I27Ir1mL86crT0vY5DvpiBAVxPBJup4,5357
63
+ unienv_interface/transformations/filter_dict.py,sha256=DzR-hgHoHJObTipxwB2UrKVlTxbfIrJohaOgqdAICLY,5871
64
+ unienv_interface/transformations/rescale.py,sha256=fM5ukWUvNvPeDO48_PRU0KyyvGhBIDxaN9XZyQ1VaQQ,4364
65
+ unienv_interface/transformations/transformation.py,sha256=u4_9H1tvophhgG0p0F3xfkMMsRuaKY2TQmVeGoeQsJ0,1652
66
+ unienv_interface/utils/control_util.py,sha256=lY_1EknglY3cNekWX9rYWt0ZUglaPMtIt4M5D9y0WfE,2351
67
+ unienv_interface/utils/data_queue.py,sha256=UZiuQDOn39DB9Heu6xinrwuzAL3X8jHlDkFoSC5Phtc,5707
68
+ unienv_interface/utils/seed_util.py,sha256=Up3nBXj7L8w-S9W5Q1U2d9accMhMf0TmHPaN6JXDVWs,677
69
+ unienv_interface/utils/stateclass.py,sha256=xjzicPGX1UuI7q3ZAxhBCCoouKfNtLywUzQtLaT0yS4,1390
70
+ unienv_interface/utils/symbol_util.py,sha256=NAERK-D_2MaTg2eYW-L75tbzPQN5YJIiKtM9zuQ89Sw,383
71
+ unienv_interface/utils/vec_util.py,sha256=EIK680ReCl_rr-qKP8co5hwz8Dx-gks8SHf-CLOZSOA,373
72
+ unienv_interface/world/__init__.py,sha256=aGuYTz8XFzW32RGkdi2b2LJ1sa0kgFrQyOR3JXDEwLQ,230
73
+ unienv_interface/world/combined_funcnode.py,sha256=O9qWxhtMJkDVtWuGyaeEj3nKMgIyRAPqF9-5LU6yna8,10853
74
+ unienv_interface/world/combined_node.py,sha256=tG7I9uWVxDDN6M6KeC1D14MV7YUnXYMUK9L9KXHnViA,9090
75
+ unienv_interface/world/funcnode.py,sha256=WvTNisOwPTwWlxC5NwQRxi-gh6MxLohh7ulctj-2YXY,7846
76
+ unienv_interface/world/funcworld.py,sha256=GLp8nS0Ym1gaj7FWvD5FPkQElCgZMbpyuLsIMU0w-sw,2020
77
+ unienv_interface/world/node.py,sha256=EAvHnx0u7IudmWQDbAUIRVEqB4kh2Xsm1aXdS3CeloY,6095
78
+ unienv_interface/world/world.py,sha256=Kl7wbNbs2YR3CjFrCLFhDB3DQUAWM6LjBwSADQtBTII,5740
79
+ unienv_interface/wrapper/__init__.py,sha256=ZNqr-WjVRqgvIxkLkeABxpYZ6tRgJNZOzmluDeJ6W_w,614
80
+ unienv_interface/wrapper/action_rescale.py,sha256=rTJlEHvWSuwGVX83cjfLWvszBk7B2iExX_K37vH8Wic,1231
81
+ unienv_interface/wrapper/backend_compat.py,sha256=T6hosgu2hrZvg3xtnyELmR6Exlz-ztqdj9vdyiz7bhI,7081
82
+ unienv_interface/wrapper/batch_and_unbatch.py,sha256=HpmnppgOKmshNlfmJYkGQYtEU7_U7q3mEdY5n4UaqEY,3457
83
+ unienv_interface/wrapper/control_frequency_limit.py,sha256=B0E2aUbaUr2p2yIN6wT3q4rAbPYsVroioqma2qKMoC0,2322
84
+ unienv_interface/wrapper/flatten.py,sha256=NWA5xne5j_L34oq_wT85wGvp6iHwdCSeGsk1DMugvRw,5837
85
+ unienv_interface/wrapper/frame_stack.py,sha256=lZZh_T_AmxsRWeYSLsTU321lVgIt12MX1eWl_yRNlWg,6002
86
+ unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKilWaQQyBkw,9789
87
+ unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
88
+ unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
89
+ unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
90
+ unienv-0.0.1b4.dist-info/METADATA,sha256=R_70XnKo1K6ObRxMmSlW1W_lxfD_rGR6txa3wBHGPOM,3033
91
+ unienv-0.0.1b4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
92
+ unienv-0.0.1b4.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
93
+ unienv-0.0.1b4.dist-info/RECORD,,
@@ -1,8 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2016 OpenAI
4
- Copyright (c) 2022 Farama Foundation
5
- Copyright (c) 2024 Yunhao Cao
3
+ Copyright (c) 2025 Yunhao Cao and UniEnv Contributors
6
4
 
7
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
8
6
  of this software and associated documentation files (the "Software"), to deal
@@ -0,0 +1,2 @@
1
+ unienv_data
2
+ unienv_interface
@@ -1,3 +1,2 @@
1
1
  from .common import BatchT, SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType, BatchBase, BatchSampler, IndexableType
2
- from .transformations import TransformedBatch
3
2
  from .storage import SpaceStorage
@@ -9,26 +9,44 @@ import dataclasses
9
9
 
10
10
  from unienv_interface.space.space_utils import batch_utils as space_batch_utils, flatten_utils as space_flatten_utils
11
11
 
12
+ __all__ = [
13
+ "BatchT",
14
+ "BatchBase",
15
+ "BatchSampler",
16
+ "IndexableType",
17
+ "convert_index_to_backendarray",
18
+ ]
19
+
12
20
  IndexableType = Union[int, slice, EllipsisType]
13
21
 
22
def convert_index_to_backendarray(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    index : IndexableType,
    length : int,
    device : Optional[BDeviceType] = None,
) -> BArrayType:
    """Convert a Python-style index into a 1-D integer array of explicit positions.

    Args:
        backend: Compute backend used to allocate the result.
        index: An integer, a ``slice``, or ``Ellipsis``. Integer-like scalars
            (anything registered as ``numbers.Integral``, e.g. ``numpy.int64``)
            are accepted too and treated like a Python ``int``. A negative
            integer is passed through unchanged (it is not wrapped by ``length``).
        length: Length of the indexed sequence; used to resolve slices and
            ``Ellipsis``.
        device: Optional device on which the result is allocated.

    Returns:
        An array of dtype ``backend.default_integer_dtype``: ``[index]`` for an
        integer, the positions selected by the slice for a ``slice``, and
        ``0 .. length-1`` for ``Ellipsis``.

    Raises:
        ValueError: If ``index`` is none of the supported kinds.
    """
    import numbers  # local import keeps the module-level import block untouched

    int_dtype = backend.default_integer_dtype
    if isinstance(index, numbers.Integral):
        # int(...) normalizes integer-like scalars (e.g. numpy.int64) so every
        # backend's asarray receives a plain Python list of ints.
        return backend.asarray([int(index)], dtype=int_dtype, device=device)
    if isinstance(index, slice):
        # slice.indices resolves None / negative bounds against `length`.
        return backend.arange(*index.indices(length), dtype=int_dtype, device=device)
    if index is Ellipsis:
        return backend.arange(length, dtype=int_dtype, device=device)
    raise ValueError("Index must be an integer, slice, or Ellipsis.")
36
+
14
37
  BatchT = TypeVar('BatchT')
15
38
  class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
16
- backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]
17
- device: Optional[BDeviceType] = None
18
-
19
39
  # If the batch is mutable, then the data can be changed (extend_*, set_*, remove_*, etc.)
20
40
  is_mutable: bool = True
21
41
 
22
42
  def __init__(
23
43
  self,
24
- single_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
44
+ single_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
25
45
  single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
26
46
  ):
27
47
  self.single_space = single_space
28
48
  self.single_metadata_space = single_metadata_space
29
- self._batched_space : Space[
30
- BatchT, Any, BDeviceType, BDtypeType, BRNGType
31
- ] = space_batch_utils.batch_space(single_space, 1)
49
+ self._batched_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType] = space_batch_utils.batch_space(single_space, 1)
32
50
  if single_metadata_space is not None:
33
51
  self._batched_metadata_space : DictSpace[
34
52
  BDeviceType, BDtypeType, BRNGType
@@ -36,24 +54,43 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
36
54
  else:
37
55
  self._batched_metadata_space = None
38
56
 
57
+ @property
58
+ def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
59
+ return self.single_space.backend
60
+
61
+ @property
62
+ def device(self) -> Optional[BDeviceType]:
63
+ return self.single_space.device
64
+
39
65
  @abc.abstractmethod
40
66
  def __len__(self) -> int:
41
67
  raise NotImplementedError
42
68
 
43
69
  def get_flattened_at(self, idx : Union[IndexableType, BArrayType]) -> BArrayType:
44
- return self.get_flattened_at_with_metadata(idx)[0]
70
+ unflattened_data = self.get_at(idx)
71
+ if isinstance(idx, int):
72
+ return space_flatten_utils.flatten_data(self.single_space, unflattened_data)
73
+ else:
74
+ return space_flatten_utils.flatten_data(self._batched_space, unflattened_data, start_dim=1)
45
75
 
46
- @abc.abstractmethod
47
76
  def get_flattened_at_with_metadata(
48
77
  self, idx : Union[IndexableType, BArrayType]
49
78
  ) -> Tuple[BArrayType, Optional[Dict[str, Any]]]:
50
- raise NotImplementedError
79
+ unflattened_data, metadata = self.get_at_with_metadata(idx)
80
+ if isinstance(idx, int):
81
+ return space_flatten_utils.flatten_data(self.single_space, unflattened_data), metadata
82
+ else:
83
+ return space_flatten_utils.flatten_data(self._batched_space, unflattened_data, start_dim=1), metadata
51
84
 
52
85
  def set_flattened_at(self, idx : Union[IndexableType, BArrayType], value : BArrayType) -> None:
53
86
  raise NotImplementedError
54
87
 
88
+ def append_flattened(self, value : BArrayType) -> None:
89
+ return self.extend_flattened(value[None])
90
+
55
91
  def extend_flattened(self, value : BArrayType) -> None:
56
- raise NotImplementedError
92
+ unflat_data = space_flatten_utils.unflatten_data(self._batched_space, value, start_dim=1)
93
+ self.extend(unflat_data)
57
94
 
58
95
  def get_at(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
59
96
  flattened_data = self.get_flattened_at(idx)
@@ -90,55 +127,81 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
90
127
  def __delitem__(self, idx : Union[IndexableType, BArrayType]) -> None:
91
128
  self.remove_at(idx)
92
129
 
130
+ def append(self, value : BatchT) -> None:
131
+ batched_data = space_batch_utils.concatenate(self._batched_space, [value])
132
+ self.extend(batched_data)
133
+
93
134
  def extend(self, value : BatchT) -> None:
94
135
  flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
95
136
  self.extend_flattened(flattened_data)
96
137
 
138
def extend_from(
    self,
    other : 'BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]',
    chunk_size : int = 8,
    tqdm : bool = False,
) -> None:
    """Append every element of ``other`` to this batch, ``chunk_size`` at a time.

    Data is read with ``other.get_at(slice(start, end))`` and written with
    ``self.extend(...)``, so at most ``chunk_size`` elements are in flight.

    Args:
        other: Source batch to copy from.
        chunk_size: Number of elements per read/extend round trip; must be >= 1.
        tqdm: When True, show a ``tqdm`` progress bar over the chunks
            (requires the optional ``tqdm`` package).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")
    n_total = len(other)
    chunk_starts = range(0, n_total, chunk_size)
    if tqdm:
        # Aliased import: the boolean parameter is also named `tqdm`.
        from tqdm import tqdm as _tqdm
        chunk_starts = _tqdm(chunk_starts, desc="Extending Batch")
    # Iterate the (possibly progress-wrapped) iterable so the bar actually advances.
    for start_idx in chunk_starts:
        end_idx = min(start_idx + chunk_size, n_total)
        self.extend(other.get_at(slice(start_idx, end_idx)))
153
+
97
154
  def close(self) -> None:
98
155
  pass
99
156
 
100
- def __del__(self) -> None:
101
- self.close()
102
-
103
157
  SamplerBatchT = TypeVar('SamplerBatchT')
104
158
  SamplerArrayType = TypeVar('SamplerArrayType')
105
159
  SamplerDeviceType = TypeVar('SamplerDeviceType')
106
160
  SamplerDtypeType = TypeVar('SamplerDtypeType')
107
161
  SamplerRNGType = TypeVar('SamplerRNGType')
108
- class BatchSampler(abc.ABC, Generic[
109
- SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType,
110
- BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
111
- ]):
112
- batch_size : int
113
- sampled_space : Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
114
- sampled_space_flat : BoxSpace[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
115
- sampled_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = None
116
-
117
- backend : ComputeBackend[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
118
- device : Optional[SamplerDeviceType] = None
119
-
162
+ class BatchSampler(
163
+ Generic[
164
+ SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType,
165
+ BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
166
+ ],
167
+ BatchBase[
168
+ SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType
169
+ ]
170
+ ):
120
171
  data : BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]
121
172
 
122
173
  rng : Optional[SamplerRNGType] = None
123
174
  data_rng : Optional[BRNGType] = None
124
-
125
- def get_flat_at(self, idx : SamplerArrayType) -> SamplerArrayType:
126
- return self.get_flat_at_with_metadata(idx)[0]
127
-
128
- @abc.abstractmethod
129
- def get_flat_at_with_metadata(
130
- self, idx : SamplerArrayType
131
- ) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
132
- raise NotImplementedError
133
175
 
134
- def get_at(self, idx : SamplerArrayType) -> SamplerBatchT:
135
- return space_flatten_utils.unflatten_data(self.sampled_space, self.get_flat_at(idx), start_dim=1)
176
+ is_mutable : bool = False
177
+
178
+ def __init__(
179
+ self,
180
+ single_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
181
+ single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
182
+ batch_size : int = 1,
183
+ ) -> None:
184
+ super().__init__(single_space=single_space, single_metadata_space=single_metadata_space)
185
+ self.batch_size = batch_size
186
+ self._batched_space : Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType] = space_batch_utils.batch_space(self.single_space, batch_size)
187
+ self._batched_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = space_batch_utils.batch_space(self.single_metadata_space, batch_size) if self.single_metadata_space is not None else None
188
+
189
def manual_seed(self, seed : int) -> None:
    """Re-create the sampler's random generators from ``seed``.

    Only generators that already exist are replaced; a generator that is
    ``None`` stays ``None``.

    ``rng`` is the sampler-side generator (``SamplerRNGType``) and is built by
    the sampler's backend on the sampler's device. ``data_rng`` is typed
    ``BRNGType``, i.e. it belongs to the wrapped data batch, so it is built by
    ``self.data.backend`` on ``self.data.device``.
    """
    if self.rng is not None:
        self.rng = self.backend.random.random_number_generator(seed, device=self.device)
    if self.data_rng is not None:
        # The sampler's type parameters allow its backend to differ from the
        # data batch's backend; the data generator must come from the latter.
        self.data_rng = self.data.backend.random.random_number_generator(seed, device=self.data.device)
+
195
+ @property
196
+ def sampled_space(self) -> Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]:
197
+ return self._batched_space
136
198
 
137
- def get_at_with_metadata(
138
- self, idx : SamplerArrayType
139
- ) -> Tuple[SamplerBatchT, Optional[Dict[str, Any]]]:
140
- flat_data, metadata = self.get_flat_at_with_metadata(idx)
141
- return space_flatten_utils.unflatten_data(self.sampled_space, flat_data, start_dim=1), metadata
199
+ @property
200
+ def sampled_metadata_space(self) -> Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]]:
201
+ return self._batched_metadata_space
202
+
203
+ def __len__(self):
204
+ return len(self.data)
142
205
 
143
206
  def sample_index(self) -> SamplerArrayType:
144
207
  new_rng, indices = self.backend.random.random_discrete_uniform( # (B, )
@@ -156,11 +219,11 @@ class BatchSampler(abc.ABC, Generic[
156
219
 
157
220
  def sample_flat(self) -> SamplerArrayType:
158
221
  idx = self.sample_index()
159
- return self.get_flat_at(idx)
222
+ return self.get_flattened_at(idx)
160
223
 
161
224
  def sample_flat_with_metadata(self) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
162
225
  idx = self.sample_index()
163
- return self.get_flat_at_with_metadata(idx)
226
+ return self.get_flattened_at_with_metadata(idx)
164
227
 
165
228
  def sample(self) -> SamplerBatchT:
166
229
  idx = self.sample_index()
@@ -205,9 +268,9 @@ class BatchSampler(abc.ABC, Generic[
205
268
  n_batches = len(self.data) // self.batch_size
206
269
  num_left = len(self.data) % self.batch_size
207
270
  for i in range(n_batches):
208
- yield self.get_flat_at(idx[i*self.batch_size:(i+1)*self.batch_size])
271
+ yield self.get_flattened_at(idx[i*self.batch_size:(i+1)*self.batch_size])
209
272
  if num_left > 0:
210
- yield self.get_flat_at(idx[-num_left:])
273
+ yield self.get_flattened_at(idx[-num_left:])
211
274
 
212
275
  def epoch_flat_iter_with_metadata(self) -> Iterator[Tuple[SamplerArrayType, Optional[Dict[str, Any]]]]:
213
276
  if self.data_rng is not None:
@@ -217,12 +280,9 @@ class BatchSampler(abc.ABC, Generic[
217
280
  n_batches = len(self.data) // self.batch_size
218
281
  num_left = len(self.data) % self.batch_size
219
282
  for i in range(n_batches):
220
- yield self.get_flat_at_with_metadata(idx[i*self.batch_size:(i+1)*self.batch_size])
283
+ yield self.get_flattened_at_with_metadata(idx[i*self.batch_size:(i+1)*self.batch_size])
221
284
  if num_left > 0:
222
- yield self.get_flat_at_with_metadata(idx[-num_left:])
285
+ yield self.get_flattened_at_with_metadata(idx[-num_left:])
223
286
 
224
287
  def close(self) -> None:
225
288
  pass
226
-
227
- def __del__(self) -> None:
228
- self.close()
@@ -31,6 +31,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
31
31
  single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
32
32
  *,
33
33
  capacity : Optional[int] = None,
34
+ read_only : bool = True,
34
35
  **kwargs
35
36
  ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
36
37
  raise NotImplementedError
@@ -56,6 +57,17 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
56
57
  """
57
58
  cache_filename : Optional[Union[str, os.PathLike]] = None
58
59
 
60
+ """
61
+ Can the storage instance be safely used in multiprocessing environments after creation?
62
+ If True, the storage can be used in multiprocessing environments.
63
+ """
64
+ is_multiprocessing_safe : bool = False
65
+
66
+ """
67
+ Is the storage mutable? If False, the storage is read-only.
68
+ """
69
+ is_mutable : bool = True
70
+
59
71
  @property
60
72
  def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
61
73
  return self.single_instance_space.backend
@@ -127,6 +139,3 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
127
139
 
128
140
  def close(self) -> None:
129
141
  pass
130
-
131
- def __del__(self) -> None:
132
- self.close()
@@ -1,4 +1,5 @@
1
1
  from .backend_compat import ToBackendOrDeviceBatch
2
2
  from .combined_batch import CombinedBatch
3
3
  from .slicestack_batch import SliceStackedBatch
4
- from .framestack_batch import FrameStackedBatch
4
+ from .framestack_batch import FrameStackedBatch
5
+ from .transformations import TransformedBatch
@@ -66,7 +66,7 @@ class ToBackendOrDeviceBatch(
66
66
  )
67
67
  self.batch = batch
68
68
  self.target_backend = backend
69
- self.device = device
69
+ self.target_device = device
70
70
 
71
71
  def __len__(self) -> int:
72
72
  return len(self.batch)
@@ -79,7 +79,18 @@ class ToBackendOrDeviceBatch(
79
79
  def backend(self) -> ComputeBackend[WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]:
80
80
  return self.target_backend if self.target_backend is not None else self.batch.backend
81
81
 
82
+ @property
83
+ def device(self) -> Optional[WrapperBDeviceT]:
84
+ return self.target_device if self.target_device is not None else self.batch.device
85
+
82
86
  def get_flattened_at(self, idx):
87
+ if self.target_backend.is_backendarray(idx):
88
+ idx = data_to(
89
+ idx,
90
+ source_backend=self.target_backend,
91
+ target_backend=self.batch.backend,
92
+ target_device=self.batch.device
93
+ )
83
94
  o_data = self.batch.get_flattened_at(idx)
84
95
  return data_to(
85
96
  o_data,
@@ -89,6 +100,13 @@ class ToBackendOrDeviceBatch(
89
100
  )
90
101
 
91
102
  def get_flattened_at_with_metadata(self, idx):
103
+ if self.target_backend.is_backendarray(idx):
104
+ idx = data_to(
105
+ idx,
106
+ source_backend=self.target_backend,
107
+ target_backend=self.batch.backend,
108
+ target_device=self.batch.device
109
+ )
92
110
  o_data, o_metadata = self.batch.get_flattened_at_with_metadata(idx)
93
111
  return (
94
112
  data_to(
@@ -107,6 +125,13 @@ class ToBackendOrDeviceBatch(
107
125
 
108
126
  def set_flattened_at(self, idx, value):
109
127
  assert self.is_mutable, "Batch is not mutable"
128
+ if self.target_backend.is_backendarray(idx):
129
+ idx = data_to(
130
+ idx,
131
+ source_backend=self.target_backend,
132
+ target_backend=self.batch.backend,
133
+ target_device=self.batch.device
134
+ )
110
135
  value = data_to(
111
136
  value,
112
137
  source_backend=self.target_backend,
@@ -126,6 +151,13 @@ class ToBackendOrDeviceBatch(
126
151
  self.batch.extend_flattened(value)
127
152
 
128
153
  def get_at(self, idx):
154
+ if self.target_backend.is_backendarray(idx):
155
+ idx = data_to(
156
+ idx,
157
+ source_backend=self.target_backend,
158
+ target_backend=self.batch.backend,
159
+ target_device=self.batch.device
160
+ )
129
161
  o_data = self.batch.get_at(idx)
130
162
  return (
131
163
  data_to(
@@ -137,6 +169,13 @@ class ToBackendOrDeviceBatch(
137
169
  )
138
170
 
139
171
  def get_at_with_metadata(self, idx):
172
+ if self.target_backend.is_backendarray(idx):
173
+ idx = data_to(
174
+ idx,
175
+ source_backend=self.target_backend,
176
+ target_backend=self.batch.backend,
177
+ target_device=self.batch.device
178
+ )
140
179
  o_data, o_metadata = self.batch.get_at_with_metadata(idx)
141
180
  return (
142
181
  data_to(
@@ -155,6 +194,13 @@ class ToBackendOrDeviceBatch(
155
194
 
156
195
  def set_at(self, idx, value):
157
196
  assert self.is_mutable, "Batch is not mutable"
197
+ if self.target_backend.is_backendarray(idx):
198
+ idx = data_to(
199
+ idx,
200
+ source_backend=self.target_backend,
201
+ target_backend=self.batch.backend,
202
+ target_device=self.batch.device
203
+ )
158
204
  o_value = data_to(
159
205
  value,
160
206
  source_backend=self.target_backend,
@@ -56,8 +56,6 @@ class CombinedBatch(BatchBase[
56
56
  )
57
57
  super().__init__(single_space, new_metadata_space)
58
58
 
59
- self.backend = backend
60
- self.device = device
61
59
  self.is_mutable = is_mutable
62
60
  self.batches = batches
63
61
  self._build_index_cache()
@@ -248,7 +246,7 @@ class CombinedBatch(BatchBase[
248
246
  result = result_space.create_empty()
249
247
  for batch_index, index_into_batch, mask in batch_list:
250
248
  result = sbu.set_at(
251
- self.single_space,
249
+ result_space,
252
250
  result,
253
251
  mask,
254
252
  self.batches[batch_index].get_at(index_into_batch),
@@ -295,7 +293,7 @@ class CombinedBatch(BatchBase[
295
293
  for batch_index, index_into_batch, mask in batch_list:
296
294
  batch_result, metadata_result = self.batches[batch_index].get_at_with_metadata(index_into_batch)
297
295
  result = sbu.set_at(
298
- self.single_space,
296
+ result_space,
299
297
  result,
300
298
  mask,
301
299
  batch_result,
@@ -1,11 +1,12 @@
1
- from typing import Optional, Any, Union
1
+ from typing import Optional, Any, Union, Tuple, Dict
2
2
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
3
3
 
4
4
  from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
5
- from .common import *
6
5
  from unienv_interface.transformations.transformation import DataTransformation, TargetDataT, SourceDataT, SourceBArrT, SourceBDeviceT, SourceBDTypeT, SourceBDRNGT
7
6
  from unienv_interface.space import Space
8
7
 
8
+ from ..base.common import BatchBase, BatchT, IndexableType
9
+
9
10
  class TransformedBatch(
10
11
  BatchBase[
11
12
  BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
@@ -63,6 +63,8 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
63
63
  **kwargs
64
64
  ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
65
65
  storage_path_relative = "storage" + (storage_cls.single_file_ext or "")
66
+ if cache_path is not None:
67
+ os.makedirs(cache_path, exist_ok=True)
66
68
  storage = storage_cls.create(
67
69
  single_instance_space,
68
70
  *args,
@@ -94,6 +96,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
94
96
  *,
95
97
  backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
96
98
  device: Optional[BDeviceType] = None,
99
+ read_only : bool = True,
97
100
  **storage_kwargs
98
101
  ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
99
102
  with open(os.path.join(path, "metadata.json"), "r") as f:
@@ -114,6 +117,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
114
117
  storage_path,
115
118
  single_instance_space,
116
119
  capacity=capacity,
120
+ read_only=read_only,
117
121
  **storage_kwargs
118
122
  )
119
123
  return ReplayBuffer(storage, metadata["storage_path_relative"], count, offset, cache_path=path)
@@ -1,3 +1,2 @@
1
1
  from .step_sampler import StepSampler
2
- from .slice_sampler import SliceSampler
3
2
  from .multiprocessing_sampler import MultiprocessingSampler