torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,2239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import collections
9
+ import importlib
10
+ import warnings
11
+ from contextlib import nullcontext
12
+ from copy import copy
13
+ from functools import partial
14
+ from types import ModuleType
15
+ from warnings import warn
16
+
17
+ import numpy as np
18
+ import torch
19
+ from packaging import version
20
+ from tensordict import TensorDict, TensorDictBase
21
+ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
22
+
23
+ TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
24
+
25
+ from torchrl._utils import implement_for, logger as torchrl_logger
26
+ from torchrl.data.tensor_specs import (
27
+ _minmax_dtype,
28
+ Binary,
29
+ Bounded,
30
+ Categorical,
31
+ Composite,
32
+ MultiCategorical,
33
+ MultiOneHot,
34
+ NonTensor,
35
+ OneHot,
36
+ TensorSpec,
37
+ Unbounded,
38
+ )
39
+ from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict
40
+ from torchrl.envs.common import _EnvPostInit
41
+ from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv
42
+ from torchrl.envs.utils import _classproperty
43
+
44
+ try:
45
+ from torch.utils._contextlib import _DecoratorContextManager
46
+ except ModuleNotFoundError:
47
+ from torchrl._utils import _DecoratorContextManager
48
+
49
# Currently active gym backend module; resolved lazily by ``gym_backend()``
# and mutated by ``set_gym_backend``.
DEFAULT_GYM = None
# Last ImportError raised while probing for a gym backend, if any.
IMPORT_ERROR = None
# check gym presence without importing it
_has_gym = importlib.util.find_spec("gym") is not None
if not _has_gym:
    _has_gym = importlib.util.find_spec("gymnasium") is not None

# Optional ecosystem packages, detected without importing them.
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
_has_minigrid = importlib.util.find_spec("minigrid") is not None
60
+
61
+
62
+ GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 versions due to incompatible
63
+ changes in the Gym API.
64
+ Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in:
65
+ * Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
66
+ * Potential data corruption, as the environment may require/produce garbage data during reset steps.
67
+ * Trajectory overlap during data collection.
68
+ * Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
69
+ * Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of
70
+ use of TorchRL.
71
+ To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time.
72
+ If you need to use gymnasium 1.0, we recommend exploring alternative solutions or waiting for future updates
73
+ to TorchRL and gymnasium that may address this compatibility issue.
74
+ For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl.
75
+ """
76
+
77
+
78
def _minigrid_lib():
    """Return the imported ``minigrid`` module.

    Callers must only invoke this when minigrid is installed; the assert
    below documents that contract (note: it is stripped under ``python -O``).
    """
    assert _has_minigrid, "minigrid not found"
    import minigrid

    return minigrid
83
+
84
+
85
class set_gym_backend(_DecoratorContextManager):
    """Sets the gym-backend to a certain value.

    Args:
        backend (python module, string or callable returning a module): the
            gym backend to use. Use a string or callable whenever you wish to
            avoid importing gym at loading time.

    Examples:
        >>> import gym
        >>> import gymnasium
        >>> with set_gym_backend("gym"):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend(lambda: gym):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend(gym):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend("gymnasium"):
        ...     assert gym_backend() == gymnasium
        >>> with set_gym_backend(lambda: gymnasium):
        ...     assert gym_backend() == gymnasium
        >>> with set_gym_backend(gymnasium):
        ...     assert gym_backend() == gymnasium

    This class can also be used as a function decorator.

    Examples:
        >>> @set_gym_backend("gym")
        ... def fun():
        ...     gym = gym_backend()
        ...     print(gym)
        >>> fun()
        <module 'gym' from '/path/to/env/site-packages/gym/__init__.py'>
        >>> @set_gym_backend("gymnasium")
        ... def fun():
        ...     gym = gym_backend()
        ...     print(gym)
        >>> fun()
        <module 'gymnasium' from '/path/to/env/site-packages/gymnasium/__init__.py'>


    """

    def __init__(self, backend):
        # `backend` may be a module, a module name or a callable returning a
        # module; resolution is deferred to the `backend` property below.
        self.backend = backend

    def _call(self):
        """Sets the backend as default."""
        global DEFAULT_GYM
        DEFAULT_GYM = self.backend
        # Per decorated-function name, record whether at least one
        # `implement_for` setter matching this backend AND version was activated.
        found_setters = collections.defaultdict(bool)
        for setter in copy(implement_for._setters):
            # A setter targets this backend when its module name (possibly a
            # callable returning the module) matches the backend's name.
            check_module = (
                callable(setter.module_name)
                and setter.module_name.__name__ == self.backend.__name__
            ) or setter.module_name == self.backend.__name__
            check_version = setter.check_version(
                self.backend.__version__, setter.from_version, setter.to_version
            )
            if check_module and check_version:
                setter.module_set()
                found_setter = True
            elif check_module:
                # Right backend, incompatible version.
                found_setter = False
            else:
                # Unrelated module: do not count it against this function.
                found_setter = None
            if found_setter is not None:
                found_setters[setter.func_name] = (
                    found_setters[setter.func_name] or found_setter
                )
        # we keep only the setters we need. This is safe because a copy is saved under self._setters_saved
        for func_name, found_setter in found_setters.items():
            if not found_setter:
                raise ImportError(
                    f"could not set anything related to gym backend "
                    f"{self.backend.__name__} with version={self.backend.__version__} for the function with name {func_name}. "
                    f"Check that the gym versions match!"
                )

    def set(self):
        """Irreversibly sets the gym backend in the script."""
        self._call()

    def __enter__(self):
        global DEFAULT_GYM
        # Save the current DEFAULT_GYM so we can restore it on exit
        self._default_gym_saved = DEFAULT_GYM
        self._call()

    def __exit__(self, exc_type, exc_val, exc_tb):
        global DEFAULT_GYM
        # Restore the previous DEFAULT_GYM
        saved_gym = self._default_gym_saved
        DEFAULT_GYM = saved_gym
        delattr(self, "_default_gym_saved")
        # Re-activate the implementations for the original backend
        # If saved_gym was None, we need to determine the default backend
        # by calling gym_backend() which will initialize DEFAULT_GYM
        if saved_gym is None:
            # Initialize DEFAULT_GYM with the default backend (gymnasium first, then gym)
            saved_gym = gym_backend()
        # Re-apply the original backend's implementations
        for setter in copy(implement_for._setters):
            check_module = (
                callable(setter.module_name)
                and setter.module_name.__name__ == saved_gym.__name__
            ) or setter.module_name == saved_gym.__name__
            check_version = setter.check_version(
                saved_gym.__version__, setter.from_version, setter.to_version
            )
            if check_module and check_version:
                setter.module_set()

    def clone(self):
        # override this method if your children class takes __init__ parameters
        return self.__class__(self.backend)

    @property
    def backend(self):
        # Lazily resolve the stored specifier into an actual module: a string
        # is imported, a callable is invoked, a module is returned as-is.
        if isinstance(self._backend, str):
            return importlib.import_module(self._backend)
        elif callable(self._backend):
            return self._backend()
        return self._backend

    @backend.setter
    def backend(self, value):
        self._backend = value
213
+
214
+
215
def gym_backend(submodule=None):
    """Returns the gym backend, or a submodule of it.

    Args:
        submodule (str): the submodule to import. If ``None``, the backend
            itself is returned.

    Examples:
        >>> import mo_gymnasium
        >>> with set_gym_backend("gym"):
        ...     wrappers = gym_backend('wrappers')
        ...     print(wrappers)
        >>> with set_gym_backend("gymnasium"):
        ...     wrappers = gym_backend('wrappers')
        ...     print(wrappers)
    """
    global IMPORT_ERROR
    global DEFAULT_GYM
    if DEFAULT_GYM is None:
        try:
            # rule of thumbs: gymnasium precedes
            import gymnasium as gym
        except ImportError as err:
            IMPORT_ERROR = err
            try:
                import gym as gym
            except ImportError as err:
                IMPORT_ERROR = err
                # Neither backend is installed; DEFAULT_GYM stays None and a
                # later submodule lookup would fail on `DEFAULT_GYM.__name__`.
                gym = None
        DEFAULT_GYM = gym
    if submodule is not None:
        # Use a relative import anchored at the backend package, so e.g.
        # "wrappers" resolves to gym(nasium).wrappers.
        if not submodule.startswith("."):
            submodule = "." + submodule
        submodule = importlib.import_module(submodule, package=DEFAULT_GYM.__name__)
        return submodule
    return DEFAULT_GYM
251
+
252
+
253
+ __all__ = ["GymWrapper", "GymEnv"]
254
+
255
+
256
+ # Define a dictionary to store conversion functions for each spec type
257
class _ConversionRegistry(collections.UserDict):
    """Mapping from gym space types to TorchRL conversion functions.

    Keys may be registered either as the space class itself or as a dotted
    path string (e.g. ``"tuple.Tuple"``) resolved lazily against the active
    backend's ``spaces`` module. Lookup falls back to the MRO of the queried
    class, so subclasses of registered spaces are also handled.
    """

    def __getitem__(self, cls):
        if cls not in super().keys():
            # We want to find the closest parent
            parents = {}
            for k in self.keys():
                if not isinstance(k, str):
                    # Key registered directly as a class.
                    parents[k] = k
                    continue
                try:
                    # Resolve dotted string keys against the active backend.
                    space_cls = gym_backend("spaces")
                    for sbsp in k.split("."):
                        space_cls = getattr(space_cls, sbsp)
                except AttributeError:
                    # Some specs may be too recent
                    continue
                parents[space_cls] = k
            mro = cls.mro()
            for base in mro:
                for p in parents:
                    if issubclass(base, p):
                        return self[parents[p]]
            else:
                # The loop never `break`s, so this `else` fires whenever no
                # registered parent matched anywhere in the MRO.
                raise KeyError(
                    f"No conversion tool could be found with the gym space {cls}. "
                    f"You can register your own with `torchrl.envs.libs.register_gym_spec_conversion.`"
                )
        return super().__getitem__(cls)
285
+
286
+
287
# Module-level singleton registry populated by `register_gym_spec_conversion`.
_conversion_registry = _ConversionRegistry()
288
+
289
+
290
def register_gym_spec_conversion(spec_type):
    """Decorator to register a conversion function for a specific spec type.

    Args:
        spec_type: either the gym space class itself, or its dotted location
            inside ``gym(nasium).spaces`` (e.g. ``"tuple.Tuple"``), resolved
            lazily against the active backend.

    The registered callable must have the following signature:

    >>> @register_gym_spec_conversion("spec.name")
    ... def convert_specname(
    ...     spec,
    ...     dtype=None,
    ...     device=None,
    ...     categorical_action_encoding=None,
    ...     remap_state_to_observation=None,
    ...     batch_size=None,
    ... ):

    If the spec type is accessible, passing the class directly also works:

    >>> @register_gym_spec_conversion(SpecType)
    ... def convert_specname(spec, **kwargs):
    ...     ...

    .. note:: Unused keyword arguments can be collapsed into ``**kwargs``.

    """

    def _register(converter):
        # Store the converter under the provided key; MRO-based fallback at
        # lookup time is handled by the registry itself.
        _conversion_registry[spec_type] = converter
        return converter

    return _register
328
+
329
+
330
def _gym_to_torchrl_spec_transform(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=False,
    remap_state_to_observation: bool = True,
    batch_size: tuple = (),
) -> TensorSpec:
    """Maps the gym specs to the TorchRL specs.

    Args:
        spec (gym.spaces member): the gym space to transform.
        dtype (torch.dtype): a dtype to use for the spec.
            Defaults to ``spec.dtype``.
        device (torch.device): the device for the spec.
            Defaults to ``None`` (no device for composite and default device for specs).
        categorical_action_encoding (bool): whether discrete spaces should be
            mapped to categorical or one-hot. Defaults to ``False`` (one-hot).
        remap_state_to_observation (bool): whether to rename the 'state' key of
            Dict specs to "observation". Default is true.
        batch_size (torch.Size): batch size to which expand the spec. Defaults
            to ``torch.Size([])``.
    """
    shared_kwargs = {
        "dtype": dtype,
        "device": device,
        "categorical_action_encoding": categorical_action_encoding,
        "remap_state_to_observation": remap_state_to_observation,
    }
    if batch_size:
        # Build the unbatched spec first, then broadcast it to the batch size.
        unbatched = _gym_to_torchrl_spec_transform(
            spec, batch_size=None, **shared_kwargs
        )
        return unbatched.expand(batch_size)
    # Dispatch on the concrete space type via the conversion registry
    # (which also handles subclasses through MRO lookup).
    converter = _conversion_registry[type(spec)]
    return converter(spec, batch_size=batch_size, **shared_kwargs)
374
+
375
+
376
+ # Register conversion functions for each spec type
377
@register_gym_spec_conversion("tuple.Tuple")
def convert_tuple_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``Tuple`` space by stacking converted sub-spaces along dim 0."""
    sub_specs = []
    for sub_space in spec:
        converted = _gym_to_torchrl_spec_transform(
            sub_space,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            remap_state_to_observation=remap_state_to_observation,
        )
        sub_specs.append(converted)
    return torch.stack(sub_specs, dim=0)
400
+
401
+
402
@register_gym_spec_conversion("discrete.Discrete")
def convert_discrete_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``Discrete`` space to a Categorical (index) or OneHot spec."""
    if categorical_action_encoding:
        # Categorical specs keep the native integer dtype of the gym space.
        return Categorical(
            spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]
        )
    # One-hot encodings are represented with torch.long entries.
    return OneHot(spec.n, device=device, dtype=torch.long)
419
+
420
+
421
@register_gym_spec_conversion("multi_binary.MultiBinary")
def convert_multi_binary_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``MultiBinary`` space to a TorchRL ``Binary`` spec.

    The extra kwargs are part of the shared conversion signature but unused
    here: the dtype is always derived from the gym space's own dtype.
    """
    return Binary(spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype])
432
+
433
+
434
@register_gym_spec_conversion("multi_discrete.MultiDiscrete")
def convert_multidiscrete_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``MultiDiscrete`` space.

    Heterogeneous 1-D ``nvec`` (e.g. ``[3, 5, 7]``) maps to a single
    MultiCategorical/MultiOneHot spec. A homogeneous ``nvec`` like ``[2, 2]``
    typically represents independent actions (e.g. vectorized envs with the
    same Discrete(n) per env) and is converted element-wise and stacked.
    """
    nvec = spec.nvec
    heterogeneous = len(nvec.shape) == 1 and len(np.unique(nvec)) > 1
    if heterogeneous:
        if categorical_action_encoding:
            return MultiCategorical(
                nvec, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]
            )
        return MultiOneHot(nvec, device=device, dtype=torch.long)
    converted = [
        _gym_to_torchrl_spec_transform(
            spec[i],
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            remap_state_to_observation=remap_state_to_observation,
        )
        for i in range(len(nvec))
    ]
    return torch.stack(converted, 0)
470
+
471
+
472
@register_gym_spec_conversion("Box")
def convert_box_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``Box`` space to a ``Bounded`` or ``Unbounded`` spec.

    A box is treated as unbounded when both bounds are infinite everywhere, or
    when they coincide with the extreme values of the target dtype (a common
    gym encoding for "unbounded").
    """
    shape = spec.shape
    if not len(shape):
        # Scalar boxes are promoted to 1-element vectors.
        shape = torch.Size([1])
    if dtype is None:
        dtype = numpy_to_torch_dtype_dict[spec.dtype]
    low = torch.as_tensor(spec.low, device=device, dtype=dtype)
    high = torch.as_tensor(spec.high, device=device, dtype=dtype)
    is_unbounded = low.isinf().all() and high.isinf().all()

    if not is_unbounded:
        # Only compute dtype extremes when the infinity check was inconclusive.
        # (Fix: the original converted minval/maxval to tensors twice — once via
        # `.to(...)` and again inside `isclose` — the second conversion was a no-op.)
        minval, maxval = _minmax_dtype(dtype)
        minval = torch.as_tensor(minval).to(low.device, dtype)
        maxval = torch.as_tensor(maxval).to(low.device, dtype)
        is_unbounded = (
            torch.isclose(low, minval).all() and torch.isclose(high, maxval).all()
        )
    if is_unbounded:
        return Unbounded(shape, device=device, dtype=dtype)
    return Bounded(
        low,
        high,
        shape,
        dtype=dtype,
        device=device,
    )
508
+
509
+
510
@register_gym_spec_conversion("Sequence")
def convert_sequence_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gymnasium ``Sequence`` space to a spec with a ragged leading dim."""
    if not hasattr(spec, "stack"):
        # gym does not have a stack attribute in sequence
        raise ValueError(
            "gymnasium should be used whenever a Sequence is present, as it needs to be stacked. "
            "If you need the gym backend at all price, please raise an issue on the TorchRL GitHub repository."
        )
    if not getattr(spec, "stack", False):
        raise ValueError(
            "Sequence spaces must have the stack argument set to ``True``. "
        )
    feature_spec = _gym_to_torchrl_spec_transform(
        spec.feature_space, device=device, dtype=dtype
    )
    # Prepend a variable-length dimension for the sequence axis.
    ragged = feature_spec.unsqueeze(0)
    ragged.make_neg_dim(0)
    return ragged
534
+
535
+
536
@register_gym_spec_conversion(dict)
def convert_dict_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a plain dict of gym spaces into a ``Composite`` spec."""
    converted = {}
    for name in spec.keys():
        # we rename "state" in "observation" as "observation" is the conventional name
        # for single observation in torchrl.
        # naming it 'state' will result in envs that have a different name for the state vector
        # when queried with and without pixels
        rename = (
            remap_state_to_observation
            and name == "state"
            and "observation" not in spec.keys()
        )
        out_key = "observation" if rename else name
        converted[out_key] = _gym_to_torchrl_spec_transform(
            spec[name],
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            remap_state_to_observation=remap_state_to_observation,
            batch_size=batch_size,
        )
    # the batch-size must be set later
    return Composite(converted, device=device)
567
+
568
+
569
@register_gym_spec_conversion("Text")
def convert_text_soec(  # NOTE(review): "soec" is a typo for "spec"; kept to avoid renaming a public symbol.
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``Text`` space to a scalar ``NonTensor`` spec carrying string data."""
    return NonTensor((), device=device, example_data="a string")
579
+
580
+
581
@register_gym_spec_conversion("dict.Dict")
def convert_dict_spec2(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    """Convert a gym ``Dict`` space by delegating to the plain-``dict`` handler.

    ``spec.spaces`` is an ordinary mapping, so recursing on it dispatches to
    the converter registered for ``dict`` (``convert_dict_spec``).
    """
    return _gym_to_torchrl_spec_transform(
        spec.spaces,
        device=device,
        categorical_action_encoding=categorical_action_encoding,
        remap_state_to_observation=remap_state_to_observation,
        batch_size=batch_size,
    )
597
+
598
+
599
@implement_for("gym", None, "0.18")
def _box_convert(spec, gym_spaces, shape):
    # gym < 0.18: Box is built from scalar bounds.
    # NOTE(review): assumes the spec's low/high are uniform — `.unique().item()`
    # raises if the bounds are not all identical; confirm against callers.
    low = spec.low.detach().unique().cpu().item()
    high = spec.high.detach().unique().cpu().item()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gym", "0.18")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    # gym >= 0.18 accepts array-valued bounds directly.
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", None, "1.0.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    # gymnasium < 1.0: same conversion as modern gym.
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    # gymnasium 1.0.x is explicitly unsupported by torchrl.
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    # gymnasium >= 1.1: array-valued bounds, as with modern gym.
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)
630
+
631
+
632
@implement_for("gym", "0.21", None)
def _multidiscrete_convert(gym_spaces, spec):
    # gym >= 0.21 accepts an explicit dtype for MultiDiscrete.
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gymnasium", None, "1.0.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    # gymnasium < 1.0: same dtype-aware construction.
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    # gymnasium 1.0.x is explicitly unsupported by torchrl.
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    # gymnasium >= 1.1: dtype-aware construction again.
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gym", None, "0.21")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    # gym < 0.21 has no dtype argument on MultiDiscrete.
    return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
661
+
662
+
663
def _torchrl_to_gym_spec_transform(
    spec,
    categorical_action_encoding=False,
) -> TensorSpec:
    """Maps TorchRL specs to gym spaces.

    Args:
        spec: the torchrl spec to transform.
        categorical_action_encoding: whether discrete spaces should be mapped to categorical or one-hot.
            Defaults to one-hot.

    Raises:
        NotImplementedError: if the spec type has no gym counterpart here.
    """
    gym_spaces = gym_backend("spaces")
    shape = spec.shape
    if any(s == -1 for s in spec.shape):
        if spec.shape[0] == -1:
            # Ragged leading dim -> gymnasium Sequence over the element spec.
            spec = spec.clone()
            spec = spec[0]
            return gym_spaces.Sequence(_torchrl_to_gym_spec_transform(spec), stack=True)
        else:
            # Ragged inner dim: fall back to a Tuple of converted sub-specs.
            # (Fix: the loop variable no longer shadows `spec`.)
            return gym_spaces.Tuple(
                tuple(
                    _torchrl_to_gym_spec_transform(sub_spec)
                    for sub_spec in spec.unbind(0)
                )
            )
    if isinstance(spec, MultiCategorical):
        return _multidiscrete_convert(gym_spaces, spec)
    if isinstance(spec, MultiOneHot):
        return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
    if isinstance(spec, Binary):
        return gym_spaces.multi_binary.MultiBinary(spec.shape[-1])
    if isinstance(spec, Categorical):
        return gym_spaces.discrete.Discrete(
            spec.n
        )  # dtype=torch_to_numpy_dtype_dict[spec.dtype])
    if isinstance(spec, OneHot):
        return gym_spaces.discrete.Discrete(spec.n)
    if isinstance(spec, Unbounded):
        # Fix: the original repeated this identical branch twice; the second
        # copy was unreachable and has been removed.
        minval, maxval = _minmax_dtype(spec.dtype)
        return gym_spaces.Box(
            low=minval,
            high=maxval,
            shape=shape,
            dtype=torch_to_numpy_dtype_dict[spec.dtype],
        )
    if isinstance(spec, Bounded):
        return _box_convert(spec, gym_spaces, shape)
    if isinstance(spec, Composite):
        # remove batch size
        while spec.shape:
            spec = spec[0]
        return gym_spaces.Dict(
            **{
                key: _torchrl_to_gym_spec_transform(
                    val,
                    categorical_action_encoding=categorical_action_encoding,
                )
                for key, val in spec.items()
            }
        )
    raise NotImplementedError(
        f"spec of type {type(spec).__name__} is currently unaccounted for"
    )
733
+
734
+
735
def _get_envs(to_dict=False) -> list:
    """Return the sorted list of env ids registered with the gym backend.

    ``to_dict`` is accepted for backward compatibility but unused.

    Raises:
        ImportError: if neither gym nor gymnasium is installed.
    """
    if not _has_gym:
        raise ImportError("Gym(nasium) could not be found in your virtual environment.")
    return sorted(_get_gym_envs())
742
+
743
+
744
@implement_for("gym", None, "0.26.0")
def _get_gym_envs():  # noqa: F811
    # Old gym nests the env specs under `registry.env_specs`.
    gym = gym_backend()
    return gym.envs.registration.registry.env_specs.keys()


@implement_for("gym", "0.26.0", None)
def _get_gym_envs():  # noqa: F811
    # gym >= 0.26: the registry itself is the mapping of env ids.
    gym = gym_backend()
    return gym.envs.registration.registry.keys()


@implement_for("gymnasium", None, "1.0.0")
def _get_gym_envs():  # noqa: F811
    # gymnasium < 1.0 mirrors the modern gym registry layout.
    gym = gym_backend()
    return gym.envs.registration.registry.keys()


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _get_gym_envs():  # noqa: F811
    # gymnasium 1.0.x is explicitly unsupported by torchrl.
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _get_gym_envs():  # noqa: F811
    # gymnasium >= 1.1 mirrors the modern gym registry layout.
    gym = gym_backend()
    return gym.envs.registration.registry.keys()
771
+
772
+
773
def _is_from_pixels(env):
    """Heuristically detect whether ``env`` produces pixel observations.

    Checks, in order: a ``"pixels"`` key in dict-like observation spaces; an
    image-looking Box space (all-zero low, all-255 high, 3 trailing channels,
    3 dims); and finally whether any wrapper in the env chain is a
    pixel-observation wrapper.
    """
    observation_spec = env.observation_space
    try:
        PixelObservationWrapper = gym_backend(
            "wrappers.pixel_observation"
        ).PixelObservationWrapper
    except ModuleNotFoundError:
        # Backend has no pixel_observation module: define a stand-in class so
        # the isinstance check below simply never matches.

        class PixelObservationWrapper:
            pass

    from torchrl.envs.libs.utils import (
        GymPixelObservationWrapper as LegacyPixelObservationWrapper,
    )

    gDict = gym_backend("spaces").dict.Dict
    Box = gym_backend("spaces").Box

    # Plain-dict observation spaces: look for a "pixels" entry.
    if isinstance(observation_spec, (dict,)):
        if "pixels" in set(observation_spec.keys()):
            return True
    # Gym Dict spaces: same "pixels" lookup via .spaces.
    # NOTE: the elif/else below pair with THIS `if`, not the one above.
    if isinstance(observation_spec, (gDict,)):
        if "pixels" in set(observation_spec.spaces.keys()):
            return True
    elif (
        isinstance(observation_spec, Box)
        and (observation_spec.low == 0).all()
        and (observation_spec.high == 255).all()
        and observation_spec.low.shape[-1] == 3
        and observation_spec.low.ndim == 3
    ):
        # uint8-style image Box (0..255, HxWx3).
        return True
    else:
        # Walk the wrapper chain looking for a pixel-observation wrapper.
        while True:
            if isinstance(
                env, (LegacyPixelObservationWrapper, PixelObservationWrapper)
            ):
                return True
            if hasattr(env, "env"):
                env = env.env
            else:
                break
    return False
816
+
817
+
818
class _GymAsyncMeta(_EnvPostInit):
    """Metaclass hook that post-processes GymWrapper/GymEnv instantiation.

    Responsibilities:
      * ``num_workers > 1`` on ``GymEnv`` builds a ``ParallelEnv`` of
        single-worker envs instead of one env;
      * batched (vectorized) envs are wrapped in a ``TransformedEnv`` with a
        ``VecGymEnvTransform`` and, when the backend hides the final
        observation on auto-reset, a ``terminal_obs_reader`` is registered to
        recover it from the info dict.
    """

    def __call__(cls, *args, **kwargs):
        missing_obs_value = kwargs.pop("missing_obs_value", None)
        num_workers = kwargs.pop("num_workers", 1)

        if cls.__name__ == "GymEnv" and num_workers > 1:
            from torchrl.envs import EnvCreator, ParallelEnv

            # Spawn `num_workers` independent copies of the env.
            env_name = args[0] if args else kwargs.get("env_name")
            env_kwargs = kwargs.copy()
            env_kwargs.pop("env_name", None)
            make_env = partial(cls, env_name, **env_kwargs)
            return ParallelEnv(num_workers, EnvCreator(make_env))

        instance: GymWrapper = super().__call__(*args, **kwargs)

        # before gym 0.22, there was no final_observation
        if instance._is_batched:
            # Fix: renamed the local from `gym_backend` to `backend_name` so it
            # no longer shadows the module-level `gym_backend()` helper.
            backend_name = instance.get_library_name(instance._env)
            from torchrl.envs.transforms.transforms import (
                TransformedEnv,
                VecGymEnvTransform,
            )

            if _has_isaaclab:
                from isaaclab.envs import ManagerBasedRLEnv

                kwargs = {}
                if missing_obs_value is not None:
                    kwargs["missing_obs_value"] = missing_obs_value
                if isinstance(instance._env.unwrapped, ManagerBasedRLEnv):
                    # IsaacLab envs get the transform but no info-dict reader.
                    return TransformedEnv(instance, VecGymEnvTransform(**kwargs))

            if _has_sb3:
                from stable_baselines3.common.vec_env.base_vec_env import VecEnv

                if isinstance(instance._env, VecEnv):
                    backend = "sb3"
                else:
                    backend = backend_name
            else:
                backend = backend_name

            # we need 3 checks: the backend is not sb3 (if so, gymnasium is used),
            # it is gym and not gymnasium and the version is before 0.22.0
            add_info_dict = True
            if backend == "gym" and backend_name == "gym":  # check gym against gymnasium
                import gym

                if version.parse(gym.__version__) < version.parse("0.22.0"):
                    warn(
                        "A batched gym environment is being wrapped in a GymWrapper with gym version < 0.22. "
                        "This implies that the next-observation is wrongly tracked (as the batched environment auto-resets "
                        "and discards the true next observation to return the result of the step). "
                        "This isn't compatible with TorchRL API and should be used with caution.",
                        category=UserWarning,
                    )
                    add_info_dict = False
            if backend_name == "gymnasium":
                import gymnasium

                if version.parse(gymnasium.__version__) >= version.parse("1.1.0"):
                    # With DISABLED autoreset the final observation is never
                    # hidden, so neither reader nor transform is needed.
                    add_info_dict = (
                        instance._env.autoreset_mode
                        != gymnasium.vector.AutoresetMode.DISABLED
                    )
                    if not add_info_dict:
                        return instance
            if add_info_dict:
                # register terminal_obs_reader
                instance.auto_register_info_dict(
                    info_dict_reader=terminal_obs_reader(
                        instance.observation_spec, backend=backend
                    )
                )
            kwargs = {}
            if missing_obs_value is not None:
                kwargs["missing_obs_value"] = missing_obs_value
            return TransformedEnv(instance, VecGymEnvTransform(**kwargs))
        return instance
898
+
899
+
900
+ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta):
901
+ """OpenAI Gym environment wrapper.
902
+
903
+ Works across `gymnasium <https://gymnasium.farama.org/>`_ and `OpenAI/gym <https://github.com/openai/gym>`_.
904
+
905
+ Args:
906
+ env (gym.Env): the environment to wrap. Batched environments (:class:`~stable_baselines3.common.vec_env.base_vec_env.VecEnv`
907
+ or :class:`gym.VectorEnv`) are supported and the environment batch-size
908
+ will reflect the number of environments executed in parallel.
909
+ categorical_action_encoding (bool, optional): if ``True``, categorical
910
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
911
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
912
+ Defaults to ``False``.
913
+
914
+ Keyword Args:
915
+ from_pixels (bool, optional): if ``True``, an attempt to return the pixel
916
+ observations from the env will be performed. By default, these observations
917
+ will be written under the ``"pixels"`` entry.
918
+ The method being used varies
919
+ depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
920
+ Defaults to ``False``.
921
+ pixels_only (bool, optional): if ``True``, only the pixel observations will
922
+ be returned (by default under the ``"pixels"`` entry in the output tensordict).
923
+ If ``False``, observations (eg, states) and pixels will be returned
924
+ whenever ``from_pixels=True``. Defaults to ``True``.
925
+ frame_skip (int, optional): if provided, indicates for how many steps the
926
+ same action is to be repeated. The observation returned will be the
927
+ last observation of the sequence, whereas the reward will be the sum
928
+ of rewards across steps.
929
+ device (torch.device, optional): if provided, the device on which the data
930
+ is to be cast. Defaults to ``torch.device("cpu")``.
931
+ batch_size (torch.Size, optional): the batch size of the environment.
932
+ Should match the leading dimensions of all observations, done states,
933
+ rewards, actions and infos.
934
+ Defaults to ``torch.Size([])``.
935
+ allow_done_after_reset (bool, optional): if ``True``, it is tolerated
936
+ for envs to be ``done`` just after :meth:`reset` is called.
937
+ Defaults to ``False``.
938
+ convert_actions_to_numpy (bool, optional): if ``True``, actions will be
939
+ converted from tensors to numpy arrays and moved to CPU before being passed to the
940
+ env step function. Set this to ``False`` if the environment is evaluated
941
+ on GPU, such as IsaacLab.
942
+ Defaults to ``True``.
943
+ missing_obs_value (Any, optional): default value to use as placeholder for missing observations, when
944
+ the environment is auto-resetting and missing observations cannot be found in the info dictionary
945
+ (e.g., with IsaacLab). This argument is passed to :class:`~torchrl.envs.VecGymEnvTransform` by
946
+ the metaclass.
947
+
948
+ Attributes:
949
+ available_envs (List[str]): a list of environments to build.
950
+
951
+ .. note::
952
+ If an attribute cannot be found, this class will attempt to retrieve it from
953
+ the nested env:
954
+
955
+ >>> from torchrl.envs import GymWrapper
956
+ >>> import gymnasium as gym
957
+ >>> env = GymWrapper(gym.make("Pendulum-v1"))
958
+ >>> print(env.spec.max_episode_steps)
959
+ 200
960
+
961
+ Examples:
962
+ >>> import gymnasium as gym
963
+ >>> from torchrl.envs import GymWrapper
964
+ >>> base_env = gym.make("Pendulum-v1")
965
+ >>> env = GymWrapper(base_env)
966
+ >>> td = env.rand_step()
967
+ >>> print(td)
968
+ TensorDict(
969
+ fields={
970
+ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
971
+ next: TensorDict(
972
+ fields={
973
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
974
+ observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
975
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
976
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
977
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
978
+ batch_size=torch.Size([]),
979
+ device=cpu,
980
+ is_shared=False)},
981
+ batch_size=torch.Size([]),
982
+ device=cpu,
983
+ is_shared=False)
984
+ >>> print(env.available_envs)
985
+ ['ALE/Adventure-ram-v5', 'ALE/Adventure-v5', 'ALE/AirRaid-ram-v5', 'ALE/AirRaid-v5', 'ALE/Alien-ram-v5', 'ALE/Alien-v5',
986
+
987
+ .. note::
988
+ info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader`
989
+ if no other reader is provided. To provide another reader, refer to
990
+ :meth:`set_info_dict_reader`. To automatically register the info_dict
991
+ content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`.
992
+ For parallel (Vectorized) environments, the info dictionary reader is automatically set and should
993
+ not be set manually.
994
+
995
+ .. note:: Gym spaces are not completely covered.
996
+ The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor
997
+ and/or within a tensordict:
998
+
999
+ - spaces.Box
1000
+ - spaces.Sequence
1001
+ - spaces.Tuple
1002
+ - spaces.Discrete
1003
+ - spaces.MultiBinary
1004
+ - spaces.MultiDiscrete
1005
+ - spaces.Dict
1006
+
1007
+ Some considerations should be made when working with gym spaces. For instance, a tuple of spaces
1008
+ can only be supported if the spaces are semantically identical (same dtype and same number of dimensions).
1009
+ Ragged dimension can be supported through :func:`~torch.nested.nested_tensor`, but then there should be only
1010
+ one level of tuple and data should be stacked along the first dimension (as nested_tensors can only be
1011
+ stacked along the first dimension).
1012
+
1013
+ Check the example in examples/envs/gym_conversion_examples.py to know more!
1014
+
1015
+ """
1016
+
1017
    git_url = "https://github.com/openai/gym"  # upstream project URL
    libname = "gym"  # library identifier used elsewhere in torchrl

    @_classproperty
    def available_envs(cls):
        """List of env ids registered with the active backend (empty when gym is absent)."""
        if not _has_gym:
            return []
        return list(_get_envs())
1025
+
1026
+ @staticmethod
1027
+ def get_library_name(env) -> str:
1028
+ """Given a gym environment, returns the backend name (either gym or gymnasium).
1029
+
1030
+ This can be used to set the appropriate backend when needed:
1031
+
1032
+ Examples:
1033
+ >>> env = gymnasium.make("Pendulum-v1")
1034
+ >>> with set_gym_backend(env):
1035
+ ... env = GymWrapper(env)
1036
+
1037
+ :class:`~GymWrapper` and similar use this method to set their method
1038
+ to the right backend during instantiation.
1039
+
1040
+ """
1041
+ try:
1042
+ import gym
1043
+
1044
+ if isinstance(env.action_space, gym.spaces.space.Space):
1045
+ return "gym"
1046
+ except ImportError:
1047
+ pass
1048
+ try:
1049
+ import gymnasium
1050
+
1051
+ if isinstance(env.action_space, gymnasium.spaces.space.Space):
1052
+ return "gymnasium"
1053
+ except ImportError:
1054
+ pass
1055
+ raise ImportError(
1056
+ f"Could not find the library of env {env}. Please file an issue on torchrl github repo."
1057
+ )
1058
+
1059
    def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
        # Lazily determined later; whether `seed()` triggers a reset depends on
        # the backend version.
        self._seed_calls_reset = None
        self._categorical_action_encoding = categorical_action_encoding
        if env is not None:
            try:
                env_str = str(env)
            except TypeError:
                # MiniGrid has a bug where the __str__ method fails
                pass
            else:
                if (
                    "EnvCompatibility" in env_str
                ):  # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
                    raise ValueError(
                        "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
                        "If this feature is needed, detail your use case in an issue of "
                        "https://github.com/pytorch/rl/issues."
                    )
            libname = self.get_library_name(env)
            self._validate_env(env)
            # Pin the detected backend (gym vs gymnasium) while the base class
            # inspects the env and builds the specs.
            with set_gym_backend(libname):
                kwargs["env"] = env
                super().__init__(**kwargs)
        else:
            super().__init__(**kwargs)
        # Bind version-specific (@implement_for) methods to this instance once.
        self._post_init()
1085
+
1086
    @implement_for("gymnasium", "1.1.0")
    def _validate_env(self, env):
        """Reject gymnasium >= 1.1 vector envs with unsupported autoreset modes.

        Only SAME_STEP and DISABLED autoreset modes are compatible with the
        TorchRL step/reset semantics.
        """
        autoreset_mode = getattr(env, "autoreset_mode", None)
        if autoreset_mode is not None:
            from gymnasium.vector import AutoresetMode

            if autoreset_mode not in (AutoresetMode.DISABLED, AutoresetMode.SAME_STEP):
                raise RuntimeError(
                    "The auto-reset mode must be one of SAME_STEP or DISABLED (which is preferred). Got "
                    f"autoreset_mode={autoreset_mode}."
                )

    @implement_for("gym", None, "1.1.0")
    def _validate_env(self, env):  # noqa
        # gym has no autoreset-mode concept: nothing to validate.
        pass

    @implement_for("gymnasium", None, "1.1.0")
    def _validate_env(self, env):  # noqa
        # gymnasium < 1.1 has no autoreset-mode attribute: nothing to validate.
        pass
1105
+
1106
    def _post_init(self):
        """Bind version-specific implementations of hot-path methods to this instance."""
        # writes the functions that are gym-version specific to the instance
        # once and for all. This is aimed at avoiding the need of decorating code
        # with set_gym_backend + allowing for parallel execution (which would
        # be troublesome when both an old version of gym and recent gymnasium
        # are present within the same virtual env).
        #
        # These calls seemingly do nothing but they actually get rid of the @implement_for decorator.
        # We execute them within the set_gym_backend context manager to make sure we get
        # the right implementation.
        #
        # This method is executed by the metaclass of GymWrapper.
        with set_gym_backend(self.get_library_name(self._env)):
            self._reset_output_transform = self._reset_output_transform
            self._output_transform = self._output_transform
1121
+
1122
+ @property
1123
+ def _is_batched(self):
1124
+ tuple_of_classes = ()
1125
+ if _has_sb3:
1126
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
1127
+
1128
+ tuple_of_classes = tuple_of_classes + (VecEnv,)
1129
+ if _has_isaaclab:
1130
+ from isaaclab.envs import ManagerBasedRLEnv
1131
+
1132
+ tuple_of_classes = tuple_of_classes + (ManagerBasedRLEnv,)
1133
+ return isinstance(
1134
+ self._env.unwrapped, tuple_of_classes + (gym_backend("vector").VectorEnv,)
1135
+ )
1136
+
1137
    @implement_for("gym")
    def _get_batch_size(self, env):
        """Prepend the vector-env worker count (if any) to the batch size (gym)."""
        if hasattr(env, "num_envs"):
            batch_size = torch.Size([env.num_envs, *self.batch_size])
        else:
            batch_size = self.batch_size
        return batch_size

    @implement_for("gymnasium", None, "1.0.0")  # gymnasium wants the unwrapped env
    def _get_batch_size(self, env):  # noqa: F811
        # gymnasium exposes num_envs on the unwrapped env.
        env_unwrapped = env.unwrapped
        if hasattr(env_unwrapped, "num_envs"):
            batch_size = torch.Size([env_unwrapped.num_envs, *self.batch_size])
        else:
            batch_size = self.batch_size
        return batch_size

    @implement_for("gymnasium", "1.0.0", "1.1.0")
    def _get_batch_size(self, env):  # noqa: F811
        # gymnasium 1.0.x is explicitly unsupported by torchrl.
        raise ImportError(GYMNASIUM_1_ERROR)

    @implement_for("gymnasium", "1.1.0")  # gymnasium wants the unwrapped env
    def _get_batch_size(self, env):  # noqa: F811
        # Same as the pre-1.0 gymnasium implementation.
        env_unwrapped = env.unwrapped
        if hasattr(env_unwrapped, "num_envs"):
            batch_size = torch.Size([env_unwrapped.num_envs, *self.batch_size])
        else:
            batch_size = self.batch_size
        return batch_size
1166
+
1167
+ def _check_kwargs(self, kwargs: dict):
1168
+ if "env" not in kwargs:
1169
+ raise TypeError("Could not find environment key 'env' in kwargs.")
1170
+ env = kwargs["env"]
1171
+ if not (hasattr(env, "action_space") and hasattr(env, "observation_space")):
1172
+ raise TypeError("env is not of type 'gym.Env'.")
1173
+
1174
    def _build_env(
        self,
        env,
        from_pixels: bool = False,
        pixels_only: bool = False,
    ) -> gym.core.Env:  # noqa: F821
        """Finalize the wrapped env: infer batch size and add pixel rendering if needed.

        Args:
            env: the gym(nasium) environment being wrapped.
            from_pixels: if True, ensure the env returns pixel observations.
            pixels_only: if True, only pixel observations are kept.
        """
        self.batch_size = self._get_batch_size(env)

        env_from_pixels = _is_from_pixels(env)
        # Pixels may be requested explicitly or already produced by the env.
        from_pixels = from_pixels or env_from_pixels
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        if from_pixels and not env_from_pixels:
            try:
                PixelObservationWrapper = gym_backend(
                    "wrappers.pixel_observation.PixelObservationWrapper"
                )
                if isinstance(env, PixelObservationWrapper):
                    raise TypeError(
                        "PixelObservationWrapper cannot be used to wrap an environment "
                        "that is already a PixelObservationWrapper instance."
                    )
            except ModuleNotFoundError:
                # Backend lacks the pixel_observation module: skip the
                # duplicate-wrapper check; _build_gym_env handles wrapping.
                pass
            env = self._build_gym_env(env, pixels_only)
        return env
1200
+
1201
+ def read_action(self, action):
1202
+ action = super().read_action(action)
1203
+ if isinstance(self.action_spec, (OneHot, Categorical)) and action.size == 1:
1204
+ # some envs require an integer for indexing
1205
+ action = int(action)
1206
+ return action
1207
+
1208
+ @implement_for("gym", None, "0.19.0")
1209
+ def _build_gym_env(self, env, pixels_only): # noqa: F811
1210
+ from .utils import GymPixelObservationWrapper as PixelObservationWrapper
1211
+
1212
+ return PixelObservationWrapper(env, pixels_only=pixels_only)
1213
+
1214
+ @implement_for("gym", "0.19.0", "0.26.0")
1215
+ def _build_gym_env(self, env, pixels_only): # noqa: F811
1216
+ pixel_observation = gym_backend("wrappers.pixel_observation")
1217
+ return pixel_observation.PixelObservationWrapper(env, pixels_only=pixels_only)
1218
+
1219
    @implement_for("gym", "0.26.0", None)
    def _build_gym_env(self, env, pixels_only):  # noqa: F811
        """Wrap ``env`` to expose pixel observations (gym >= 0.26).

        If the env was created with a ``render_mode``, the native
        PixelObservationWrapper is used; otherwise a legacy path via
        EnvCompatibility + TorchRL's own wrapper is taken.
        """
        compatibility = gym_backend("wrappers.compatibility")
        pixel_observation = gym_backend("wrappers.pixel_observation")

        if env.render_mode:
            return pixel_observation.PixelObservationWrapper(
                env, pixels_only=pixels_only
            )

        warnings.warn(
            "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
            "should be created with `gym.make(env_name, render_mode=mode)` where possible,"
            'where mode is either "rgb_array" or any other supported mode.'
        )
        # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper
        env = compatibility.EnvCompatibility(env)
        env.reset()
        from torchrl.envs.libs.utils import (
            GymPixelObservationWrapper as LegacyPixelObservationWrapper,
        )

        return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
1243
+ @implement_for("gymnasium", "1.0.0", "1.1.0")
1244
+ def _build_gym_env(self, env, pixels_only): # noqa: F811
1245
+ raise ImportError(GYMNASIUM_1_ERROR)
1246
+
1247
    @implement_for("gymnasium", None, "1.0.0")
    def _build_gym_env(self, env, pixels_only):  # noqa: F811
        """Wrap ``env`` to expose pixel observations (gymnasium < 1.0).

        Mirrors the gym >= 0.26 path: use the native PixelObservationWrapper
        when a ``render_mode`` is set, otherwise fall back to EnvCompatibility
        plus TorchRL's legacy wrapper.
        """
        compatibility = gym_backend("wrappers.compatibility")
        pixel_observation = gym_backend("wrappers.pixel_observation")

        if env.render_mode:
            return pixel_observation.PixelObservationWrapper(
                env, pixels_only=pixels_only
            )

        warnings.warn(
            "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
            "should be created with `gym.make(env_name, render_mode=mode)` where possible,"
            'where mode is either "rgb_array" or any other supported mode.'
        )
        # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper
        env = compatibility.EnvCompatibility(env)
        env.reset()
        from torchrl.envs.libs.utils import (
            GymPixelObservationWrapper as LegacyPixelObservationWrapper,
        )

        return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
1271
    @implement_for("gymnasium", "1.1.0")
    def _build_gym_env(self, env, pixels_only):  # noqa: F811
        """Wrap ``env`` to expose pixel observations (gymnasium >= 1.1).

        gymnasium 1.1 renamed the pixel wrapper to ``AddRenderObservation``;
        without a ``render_mode`` we fall back to TorchRL's legacy wrapper.
        """
        wrappers = gym_backend("wrappers")

        if env.render_mode:
            # NOTE(review): pixels_only is passed as render_only here —
            # confirm the intended mapping against gymnasium's API.
            return wrappers.AddRenderObservation(env, render_only=pixels_only)

        warnings.warn(
            "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
            "should be created with `gym.make(env_name, render_mode=mode)` where possible,"
            'where mode is either "rgb_array" or any other supported mode.'
        )
        env.reset()
        from torchrl.envs.libs.utils import (
            GymPixelObservationWrapper as LegacyPixelObservationWrapper,
        )

        return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
1290
+ @property
1291
+ def lib(self) -> ModuleType:
1292
+ gym = gym_backend()
1293
+ if gym is None:
1294
+ raise RuntimeError(
1295
+ "Gym backend is not available. Please install gym or gymnasium."
1296
+ )
1297
+ return gym
1298
+
1299
+ def _set_seed(self, seed: int | None) -> None: # noqa: F811
1300
+ if self._seed_calls_reset is None:
1301
+ # Determine basing on gym version whether `reset` is called when setting seed.
1302
+ self._set_seed_initial(seed)
1303
+ elif self._seed_calls_reset:
1304
+ self.reset(seed=seed)
1305
+ else:
1306
+ self._env.seed(seed=seed)
1307
+
1308
+ @implement_for("gym", None, "0.15.0")
1309
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1310
+ self._seed_calls_reset = False
1311
+ self._env.seed(seed)
1312
+
1313
+ @implement_for("gym", "0.15.0", "0.19.0")
1314
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1315
+ self._seed_calls_reset = False
1316
+ self._env.seed(seed=seed)
1317
+
1318
+ @implement_for("gym", "0.19.0", "0.21.0")
1319
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1320
+ # In gym 0.19-0.21, reset() doesn't accept seed kwarg yet,
1321
+ # and VectorEnv.seed uses seeds= (plural) instead of seed=
1322
+ self._seed_calls_reset = False
1323
+ if hasattr(self._env, "num_envs"):
1324
+ # Vector environment uses seeds= (plural)
1325
+ self._env.seed(seeds=seed)
1326
+ else:
1327
+ self._env.seed(seed=seed)
1328
+
1329
    @implement_for("gym", "0.21.0", None)
    def _set_seed_initial(self, seed: int) -> None:  # noqa: F811
        """Probe whether reset(seed=...) works; otherwise fall back to env.seed.

        The result is cached in ``self._seed_calls_reset`` for later calls.
        If neither path works, the original reset error is re-raised, chained
        to the seed() AttributeError.
        """
        try:
            self.reset(seed=seed)
            self._seed_calls_reset = True
        except TypeError as err:
            warnings.warn(
                f"reset with seed kwarg returned an exception: {err}.\n"
                f"Calling env.seed from now on."
            )
            self._seed_calls_reset = False
            try:
                self._env.seed(seed=seed)
            except AttributeError as err2:
                raise err from err2
1345
+ @implement_for("gymnasium", "1.0.0", "1.1.0")
1346
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1347
+ raise ImportError(GYMNASIUM_1_ERROR)
1348
+
1349
+ @implement_for("gymnasium", None, "1.0.0")
1350
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1351
+ try:
1352
+ self.reset(seed=seed)
1353
+ self._seed_calls_reset = True
1354
+ except TypeError as err:
1355
+ warnings.warn(
1356
+ f"reset with seed kwarg returned an exception: {err}.\n"
1357
+ f"Calling env.seed from now on."
1358
+ )
1359
+ self._seed_calls_reset = False
1360
+ self._env.seed(seed=seed)
1361
+
1362
+ @implement_for("gymnasium", "1.1.0")
1363
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
1364
+ try:
1365
+ self.reset(seed=seed)
1366
+ self._seed_calls_reset = True
1367
+ except TypeError as err:
1368
+ warnings.warn(
1369
+ f"reset with seed kwarg returned an exception: {err}.\n"
1370
+ f"Calling env.seed from now on."
1371
+ )
1372
+ self._seed_calls_reset = False
1373
+ self._env.seed(seed=seed)
1374
+
1375
+ @implement_for("gym")
1376
+ def _reward_space(self, env):
1377
+ if hasattr(env, "reward_space") and env.reward_space is not None:
1378
+ return env.reward_space
1379
+
1380
+ @implement_for("gymnasium", "1.0.0", "1.1.0")
1381
+ def _reward_space(self, env): # noqa: F811
1382
+ raise ImportError(GYMNASIUM_1_ERROR)
1383
+
1384
+ @implement_for("gymnasium", None, "1.0.0")
1385
+ def _reward_space(self, env): # noqa: F811
1386
+ env = env.unwrapped
1387
+ if hasattr(env, "reward_space") and env.reward_space is not None:
1388
+ rs = env.reward_space
1389
+ return rs
1390
+
1391
+ @implement_for("gymnasium", "1.1.0")
1392
+ def _reward_space(self, env): # noqa: F811
1393
+ env = env.unwrapped
1394
+ if hasattr(env, "reward_space") and env.reward_space is not None:
1395
+ rs = env.reward_space
1396
+ return rs
1397
+
1398
    def _make_specs(self, env: gym.Env, batch_size=None) -> None:  # noqa: F821
        """Build TorchRL observation/action/reward/done specs from ``env``'s spaces.

        Args:
            env: the wrapped gym(-nasium) environment.
            batch_size (torch.Size, optional): if provided, specs are built
                batch-free and then expanded to this batch size instead of
                using ``self.batch_size``.
        """
        # If batch_size is provided, we set it to tell what batch size must be used
        # instead of self.batch_size
        cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
        observation_spec = _gym_to_torchrl_spec_transform(
            env.observation_space,
            device=self.device,
            categorical_action_encoding=self._categorical_action_encoding,
        )
        action_spec = _gym_to_torchrl_spec_transform(
            env.action_space,
            device=self.device,
            categorical_action_encoding=self._categorical_action_encoding,
        )
        # When the action space is MultiDiscrete and an action_mask is present in the
        # observation with shape matching nvec, we convert to a flattened Categorical/OneHot
        # so that the mask can be applied directly to all possible action combinations.
        # This is useful for grid-based games where the mask indicates valid (row, col) positions.
        gym_spaces = gym_backend("spaces")
        MultiDiscrete = getattr(gym_spaces, "MultiDiscrete", None)
        if MultiDiscrete is None:
            # Fallback for gym versions where MultiDiscrete is in a submodule
            multi_discrete_module = getattr(gym_spaces, "multi_discrete", None)
            if multi_discrete_module is not None:
                MultiDiscrete = getattr(multi_discrete_module, "MultiDiscrete", None)
        if MultiDiscrete is not None and isinstance(env.action_space, MultiDiscrete):
            nvec = np.asarray(env.action_space.nvec)
            if (
                nvec.ndim == 1
                and isinstance(observation_spec, Composite)
                and "action_mask" in observation_spec
            ):
                mask_spec = observation_spec["action_mask"]
                if tuple(mask_spec.shape) == tuple(nvec):
                    prod_n = int(np.prod(nvec))
                    dtype = (
                        numpy_to_torch_dtype_dict[env.action_space.dtype]
                        if self._categorical_action_encoding
                        else torch.long
                    )
                    # Flattened action: single choice from prod(nvec) options.
                    # The mask (which has shape matching nvec) will be reshaped
                    # by Categorical/OneHot.update_mask when applied.
                    if self._categorical_action_encoding:
                        action_spec = Categorical(
                            prod_n,
                            shape=(),
                            device=self.device,
                            dtype=dtype,
                        )
                    else:
                        action_spec = OneHot(
                            prod_n,
                            shape=(prod_n,),
                            device=self.device,
                            dtype=torch.bool,
                        )
        # Non-composite observation spaces are nested under "pixels" or
        # "observation" depending on the rendering mode.
        if not isinstance(observation_spec, Composite):
            if self.from_pixels:
                observation_spec = Composite(
                    pixels=observation_spec, shape=cur_batch_size
                )
            else:
                observation_spec = Composite(
                    observation=observation_spec, shape=cur_batch_size
                )
        elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
            observation_spec.shape = cur_batch_size

        # Use the env-provided reward space when available, otherwise assume
        # an unbounded scalar reward.
        reward_space = self._reward_space(env)
        if reward_space is not None:
            reward_spec = _gym_to_torchrl_spec_transform(
                reward_space,
                device=self.device,
                categorical_action_encoding=self._categorical_action_encoding,
            )
        else:
            reward_spec = Unbounded(
                shape=[1],
                device=self.device,
            )
        if batch_size is not None:
            action_spec = action_spec.expand(*batch_size, *action_spec.shape)
            reward_spec = reward_spec.expand(*batch_size, *reward_spec.shape)
            observation_spec = observation_spec.expand(
                *batch_size, *observation_spec.shape
            )

        self.done_spec = self._make_done_spec()
        self.action_spec = action_spec
        if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
            self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
        else:
            self.reward_spec = reward_spec
        self.observation_spec = observation_spec
1494
+ @implement_for("gym", None, "0.26")
1495
+ def _make_done_spec(self): # noqa: F811
1496
+ return Composite(
1497
+ {
1498
+ "done": Categorical(
1499
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1500
+ ),
1501
+ "terminated": Categorical(
1502
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1503
+ ),
1504
+ "truncated": Categorical(
1505
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1506
+ ),
1507
+ },
1508
+ shape=self.batch_size,
1509
+ )
1510
+
1511
+ @implement_for("gym", "0.26", None)
1512
+ def _make_done_spec(self): # noqa: F811
1513
+ return Composite(
1514
+ {
1515
+ "done": Categorical(
1516
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1517
+ ),
1518
+ "terminated": Categorical(
1519
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1520
+ ),
1521
+ "truncated": Categorical(
1522
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1523
+ ),
1524
+ },
1525
+ shape=self.batch_size,
1526
+ )
1527
+
1528
+ @implement_for("gymnasium", "0.27", None)
1529
+ def _make_done_spec(self): # noqa: F811
1530
+ return Composite(
1531
+ {
1532
+ "done": Categorical(
1533
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1534
+ ),
1535
+ "terminated": Categorical(
1536
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1537
+ ),
1538
+ "truncated": Categorical(
1539
+ 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
1540
+ ),
1541
+ },
1542
+ shape=self.batch_size,
1543
+ )
1544
+
1545
+ @implement_for("gym", None, "0.26")
1546
+ def _reset_output_transform(self, reset_data): # noqa: F811
1547
+ if (
1548
+ isinstance(reset_data, tuple)
1549
+ and len(reset_data) == 2
1550
+ and isinstance(reset_data[1], dict)
1551
+ ):
1552
+ return reset_data
1553
+ return reset_data, None
1554
+
1555
+ @implement_for("gym", "0.26", None)
1556
+ def _reset_output_transform(self, reset_data): # noqa: F811
1557
+ return reset_data
1558
+
1559
+ @implement_for("gymnasium", "0.27", None)
1560
+ def _reset_output_transform(self, reset_data): # noqa: F811
1561
+ return reset_data
1562
+
1563
    @implement_for("gym", None, "0.24")
    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Normalize an old-gym 4-tuple step output to TorchRL's 6-tuple.

        Returns:
            ``(observations, reward, terminated, truncated, done, info)``,
            where ``terminated``/``truncated`` are derived from ``done`` and
            the ``TimeLimit.truncated`` info entry.
        """
        observations, reward, done, info = step_outputs_tuple
        if self._is_batched:
            # info needs to be flipped
            info = _flip_info_tuple(info)
        # The variable naming follows torchrl's convention here.
        # A done is interpreted the union of terminated and truncated.
        # (as in earlier versions of gym).
        truncated = info.pop("TimeLimit.truncated", False)
        if not isinstance(done, bool) and isinstance(truncated, bool):
            # if bool is an array, make truncated an array
            truncated = [truncated] * len(done)
            truncated = np.array(truncated)
        elif not isinstance(truncated, bool):
            # make sure it's a boolean np.array
            truncated = np.array(truncated, dtype=np.dtype("bool"))
        terminated = done & ~truncated
        if not isinstance(terminated, np.ndarray):
            # if it's not a ndarray, we must return bool
            # since it's not a bool, we make it so
            terminated = bool(terminated)

        if isinstance(observations, list) and len(observations) == 1:
            # Until gym 0.25.2 we had rendered frames returned in lists of length 1
            observations = observations[0]

        return (observations, reward, terminated, truncated, done, info)
1592
    @implement_for("gym", "0.24", "0.26")
    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Normalize a gym 0.24/0.25 4-tuple step output to TorchRL's 6-tuple.

        Returns:
            ``(observations, reward, terminated, truncated, done, info)``,
            where ``terminated``/``truncated`` are derived from ``done`` and
            the ``TimeLimit.truncated`` info entry.
        """
        observations, reward, done, info = step_outputs_tuple
        # The variable naming follows torchrl's convention here.
        # A done is interpreted the union of terminated and truncated.
        # (as in earlier versions of gym).
        truncated = info.pop("TimeLimit.truncated", False)
        if not isinstance(done, bool) and isinstance(truncated, bool):
            # if bool is an array, make truncated an array
            truncated = [truncated] * len(done)
            truncated = np.array(truncated)
        elif not isinstance(truncated, bool):
            # make sure it's a boolean np.array
            truncated = np.array(truncated, dtype=np.dtype("bool"))
        terminated = done & ~truncated
        if not isinstance(terminated, np.ndarray):
            # if it's not a ndarray, we must return bool
            # since it's not a bool, we make it so
            terminated = bool(terminated)

        if isinstance(observations, list) and len(observations) == 1:
            # Until gym 0.25.2 we had rendered frames returned in lists of length 1
            observations = observations[0]

        return (observations, reward, terminated, truncated, done, info)
1618
+ @implement_for("gym", "0.26", None)
1619
+ def _output_transform(self, step_outputs_tuple): # noqa: F811
1620
+ # The variable naming follows torchrl's convention here.
1621
+ observations, reward, terminated, truncated, info = step_outputs_tuple
1622
+ return (
1623
+ observations,
1624
+ reward,
1625
+ terminated,
1626
+ truncated,
1627
+ terminated | truncated,
1628
+ info,
1629
+ )
1630
+
1631
+ @implement_for("gymnasium", "0.27", None)
1632
+ def _output_transform(self, step_outputs_tuple): # noqa: F811
1633
+ # The variable naming follows torchrl's convention here.
1634
+ observations, reward, terminated, truncated, info = step_outputs_tuple
1635
+ return (
1636
+ observations,
1637
+ reward,
1638
+ terminated,
1639
+ truncated,
1640
+ terminated | truncated,
1641
+ info,
1642
+ )
1643
+
1644
+ def _init_env(self):
1645
+ pass
1646
+ # init_reset = self.init_reset
1647
+ # if init_reset is None:
1648
+ # warnings.warn(f"init_env is None in the {type(self).__name__} constructor. The current "
1649
+ # f"default behavior is to reset the gym env as soon as it's wrapped in the "
1650
+ # f"class (init_reset=True), but from v0.9 this will be changed to False. "
1651
+ # f"To adapt for these changes, pass init_reset to your constructor.", category=FutureWarning)
1652
+ # init_reset = True
1653
+ # if init_reset:
1654
+ # self._env.reset()
1655
+
1656
+ def __repr__(self) -> str:
1657
+ return (
1658
+ f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
1659
+ )
1660
+
1661
+ def rebuild_with_kwargs(self, **new_kwargs):
1662
+ self._constructor_kwargs.update(new_kwargs)
1663
+ self._env = self._build_env(**self._constructor_kwargs)
1664
+ self._make_specs(self._env)
1665
+
1666
+ @implement_for("gym")
1667
+ def _replace_reset(self, reset, kwargs):
1668
+ return kwargs
1669
+
1670
+ @implement_for("gymnasium", None, "1.1.0")
1671
+ def _replace_reset(self, reset, kwargs): # noqa
1672
+ return kwargs
1673
+
1674
    # From gymnasium 1.1.0, AutoresetMode.DISABLED is like resets in torchrl
    @implement_for("gymnasium", "1.1.0")
    def _replace_reset(self, reset, kwargs):  # noqa
        """Translate a TorchRL ``_reset`` mask into gymnasium reset options.

        When the vector env runs with autoreset disabled, the boolean mask is
        forwarded as ``options["reset_mask"]`` so only the selected sub-envs
        are reset. ``kwargs`` is mutated in place and returned.
        """
        import gymnasium as gym

        if (
            getattr(self._env, "autoreset_mode", None)
            == gym.vector.AutoresetMode.DISABLED
        ):
            options = {"reset_mask": reset.view(self.batch_size).numpy()}
            kwargs.setdefault("options", {}).update(options)
        return kwargs
1687
    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        """Reset the env, special-casing batched (vectorized) backends.

        For batched envs, a partial ``_reset`` mask may skip the actual reset
        (sub-envs auto-reset); a full mask or explicit reset options triggers
        a real reset.
        """
        if self._is_batched:
            # batched (aka 'vectorized') env reset is a bit special: envs are
            # automatically reset. What we do here is just to check if _reset
            # is present. If it is not, we just reset. Otherwise, we just skip.
            if tensordict is None:
                return super()._reset(tensordict, **kwargs)
            reset = tensordict.get("_reset", None)
            kwargs = self._replace_reset(reset, kwargs)
            if reset is not None:
                # we must copy the tensordict because the transform
                # expects a tuple (tensordict, tensordict_reset) where the
                # first still carries a _reset
                tensordict = tensordict.exclude("_reset")
            if reset is None or reset.all() or "options" in kwargs:
                result = super()._reset(tensordict, **kwargs)
                return result
            else:
                # Partial reset handled by auto-resetting sub-envs: nothing
                # to do here.
                return tensordict
        return super()._reset(tensordict, **kwargs)
1710
+
1711
# Substrings of TypeError messages that GymEnv._build_env knows how to
# recover from when ``gym.make`` rejects one of the preset kwargs.
ACCEPTED_TYPE_ERRORS = {
    "render_mode": "__init__() got an unexpected keyword argument 'render_mode'",
    "frame_skip": "unexpected keyword argument 'frameskip'",
}
1716
+
1717
class GymEnv(GymWrapper):
    """OpenAI Gym environment wrapper constructed by environment ID directly.

    Works across `gymnasium <https://gymnasium.farama.org/>`_ and `OpenAI/gym <https://github.com/openai/gym>`_.

    Args:
        env_name (str): the environment id registered in `gym.registry`.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        num_envs (int, optional): the number of envs to run in parallel. Defaults to
            ``None`` (a single env is to be run). :class:`~gym.vector.AsyncVectorEnv`
            will be used by default.
        num_workers (int, optional): number of top-level worker subprocesses used to create/run
            multiple :class:`GymEnv` instances in parallel (handled by the metaclass
            :class:`_GymAsyncMeta`). When ``num_workers > 1``, a lazy
            :class:`~torchrl.envs.ParallelEnv` is returned whose factory preserves the original
            `GymEnv` kwargs. You can modify the ParallelEnv construction/configuration before
            it starts by calling :meth:`~torchrl.envs.batched_envs.BatchedEnvBase.configure_parallel`
            on the returned object (for example: ``env.configure_parallel(use_buffers=True, num_threads=2)``).
            When both ``num_workers`` and ``num_envs`` are greater than 1, the total number of
            environments executed in parallel is ``num_workers * num_envs``. Defaults to ``1``.
        disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
            for these versions), the environment checker won't be run.
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``False``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (List[str]): the list of envs that can be built.

    .. note::
        If an attribute cannot be found, this class will attempt to retrieve it from
        the nested env:

        >>> from torchrl.envs import GymEnv
        >>> env = GymEnv("Pendulum-v1")
        >>> print(env.spec.max_episode_steps)
        200


    If a use-case is not covered by TorchRL, please submit an issue on GitHub.

    Examples:
        >>> from torchrl.envs import GymEnv
        >>> env = GymEnv("Pendulum-v1")
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['ALE/Adventure-ram-v5', 'ALE/Adventure-v5', 'ALE/AirRaid-ram-v5', 'ALE/AirRaid-v5', 'ALE/Alien-ram-v5', 'ALE/Alien-v5',

    To run multiple environments in parallel:
        >>> from torchrl.envs import GymEnv
        >>> env = GymEnv("Pendulum-v1", num_workers=4)
        >>> td_reset = env.reset()
        >>> td = env.rand_step(td_reset)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=None,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)

    .. note::
        If both `OpenAI/gym` and `gymnasium` are present in the virtual environment,
        one can swap backend using :func:`~torchrl.envs.libs.gym.set_gym_backend`:

        >>> from torchrl.envs import set_gym_backend, GymEnv
        >>> with set_gym_backend("gym"):
        ...     env = GymEnv("Pendulum-v1")
        ...     print(env._env)
        <class 'gym.wrappers.time_limit.TimeLimit'>
        >>> with set_gym_backend("gymnasium"):
        ...     env = GymEnv("Pendulum-v1")
        ...     print(env._env)
        <class 'gymnasium.wrappers.time_limit.TimeLimit'>

    .. note::
        info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader`
        if no other reader is provided. To provide another reader, refer to
        :meth:`set_info_dict_reader`. To automatically register the info_dict
        content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`.

    .. note:: Gym spaces are not completely covered.
        The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor
        and/or within a tensordict:

        - spaces.Box
        - spaces.Sequence
        - spaces.Tuple
        - spaces.Discrete
        - spaces.MultiBinary
        - spaces.MultiDiscrete
        - spaces.Dict

        Some considerations should be made when working with gym spaces. For instance, a tuple of spaces
        can only be supported if the spaces are semantically identical (same dtype and same number of dimensions).
        Ragged dimension can be supported through :func:`~torch.nested.nested_tensor`, but then there should be only
        one level of tuple and data should be stacked along the first dimension (as nested_tensors can only be
        stacked along the first dimension).

        Check the example in examples/envs/gym_conversion_examples.py to know more!

    """

    def __init__(self, env_name, **kwargs):
        # An explicit backend ("gym"/"gymnasium") temporarily overrides the
        # global gym backend while the env is constructed.
        backend = kwargs.pop("backend", None)
        with set_gym_backend(backend) if backend is not None else nullcontext():
            kwargs["env_name"] = env_name
            self._set_gym_args(kwargs)
            super().__init__(**kwargs)

    @implement_for("gym", None, "0.24.0")
    def _set_gym_args(self, kwargs) -> None:  # noqa: F811
        # The env checker does not exist before gym 0.24; reject the kwarg.
        disable_env_checker = kwargs.pop("disable_env_checker", None)
        if disable_env_checker is not None:
            raise RuntimeError(
                "disable_env_checker should only be set if gym version is > 0.24"
            )

    @implement_for("gym", "0.24.0", None)
    def _set_gym_args(  # noqa: F811
        self,
        kwargs,
    ) -> None:
        # Disable the env checker by default on versions that support it.
        kwargs.setdefault("disable_env_checker", True)

    @implement_for("gymnasium", "1.0.0", "1.1.0")
    def _set_gym_args(  # noqa: F811
        self,
        kwargs,
    ) -> None:
        # gymnasium 1.0.x is not supported by TorchRL.
        raise ImportError(GYMNASIUM_1_ERROR)

    @implement_for("gymnasium", None, "1.0.0")
    def _set_gym_args(  # noqa: F811
        self,
        kwargs,
    ) -> None:
        kwargs.setdefault("disable_env_checker", True)

    @implement_for("gymnasium", "1.1.0")
    def _set_gym_args(  # noqa: F811
        self,
        kwargs,
    ) -> None:
        kwargs.setdefault("disable_env_checker", True)

    def _async_env(self, *args, **kwargs):
        # Factory for the backend's asynchronous vectorized env.
        return gym_backend("vector").AsyncVectorEnv(*args, **kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> gym.core.Env:  # noqa: F821
        """Instantiate the gym env by name, retrying without unsupported kwargs.

        Args:
            env_name: registered gym env id.
            **kwargs: forwarded to ``gym.make``; ``from_pixels``, ``pixels_only``
                and ``num_envs`` are consumed here.

        Returns:
            The constructed (and possibly vectorized / pixel-wrapped) env.
        """
        if not _has_gym:
            raise RuntimeError(
                f"gym not found, unable to create {env_name}. "
                f"Consider downloading and installing gym from"
                f" {self.git_url}"
            )
        from_pixels = kwargs.pop("from_pixels", False)
        self._set_gym_default(kwargs, from_pixels)
        pixels_only = kwargs.pop("pixels_only", True)
        num_envs = kwargs.pop("num_envs", 0)
        made_env = False
        kwargs["frameskip"] = self.frame_skip
        self.wrapper_frame_skip = 1
        while not made_env:
            # env.__init__ may not be compatible with all the kwargs that
            # have been preset. We iterate through the various solutions
            # to find the config that works.
            try:
                with warnings.catch_warnings(record=True) as w:
                    if env_name.startswith("ALE/"):
                        # Importing ale_py registers the ALE env ids.
                        try:
                            import ale_py  # noqa: F401
                        except ImportError as err:
                            torchrl_logger.warning(
                                f"ale_py not found, this may cause issues with ALE environments: {err}"
                            )
                    # we catch warnings as they may cause silent bugs
                    env = self.lib.make(env_name, **kwargs)
                    if len(w) and "frameskip" in str(w[-1].message):
                        raise TypeError("unexpected keyword argument 'frameskip'")
                made_env = True
            except TypeError as err:
                if ACCEPTED_TYPE_ERRORS["frame_skip"] in str(err):
                    # we can disable this, not strictly indispensable to know
                    # warn(
                    #     "Discarding frameskip arg. This will be taken care of by TorchRL env wrapper."
                    # )
                    self.wrapper_frame_skip = kwargs.pop("frameskip")
                elif ACCEPTED_TYPE_ERRORS["render_mode"] in str(err):
                    warn("Discarding render_mode from the env constructor.")
                    kwargs.pop("render_mode")
                else:
                    raise err
        env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels)
        if num_envs > 0:
            # Replace the single env with an async vectorized one of the
            # same configuration.
            make_fn = partial(self.lib.make, env_name, **kwargs)
            env = self._async_env([make_fn] * num_envs)
            self.batch_size = torch.Size([num_envs, *self.batch_size])
        return env

    @implement_for("gym", None, "0.25.0")
    def _set_gym_default(self, kwargs, from_pixels: bool) -> None:  # noqa: F811
        # Do nothing for older gym versions (render_mode was introduced in 0.25.0).
        pass

    @implement_for("gym", "0.25.0", None)
    def _set_gym_default(self, kwargs, from_pixels: bool) -> None:  # noqa: F811
        if from_pixels:
            kwargs.setdefault("render_mode", "rgb_array")

    @implement_for("gymnasium", None, "0.27.0")
    def _set_gym_default(self, kwargs, from_pixels: bool) -> None:  # noqa: F811
        # gymnasium < 0.27.0 also supports render_mode (forked from gym 0.26+)
        if from_pixels:
            kwargs.setdefault("render_mode", "rgb_array")

    @implement_for("gymnasium", "0.27.0", None)
    def _set_gym_default(self, kwargs, from_pixels: bool) -> None:  # noqa: F811
        if from_pixels:
            kwargs.setdefault("render_mode", "rgb_array")

    @property
    def env_name(self):
        # The env id the wrapper was constructed with.
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: dict):
        # GymEnv is built from an env id, not an env instance.
        if "env_name" not in kwargs:
            raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
2009
+
2010
class MOGymWrapper(GymWrapper):
    """Wrapper class for FARAMA MO-Gymnasium environments.

    Examples:
        >>> import mo_gymnasium as mo_gym
        >>> env = MOGymWrapper(mo_gym.make('minecart-v0'), frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)

    """

    git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
    libname = "mo-gymnasium"

    # Spec construction is delegated to GymEnv, forced onto the gymnasium backend.
    _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)

    @_classproperty
    def available_envs(cls):
        """Known MO-Gymnasium env ids; empty when mo-gymnasium is not installed."""
        if not _has_mo:
            return []
        env_ids = (
            "deep-sea-treasure-v0",
            "deep-sea-treasure-concave-v0",
            "resource-gathering-v0",
            "fishwood-v0",
            "breakable-bottles-v0",
            "fruit-tree-v0",
            "water-reservoir-v0",
            "four-room-v0",
            "mo-mountaincar-v0",
            "mo-mountaincarcontinuous-v0",
            "mo-lunar-lander-v2",
            "minecart-v0",
            "mo-highway-v0",
            "mo-highway-fast-v0",
            "mo-supermario-v0",
            "mo-reacher-v4",
            "mo-hopper-v4",
            "mo-halfcheetah-v4",
        )
        return list(env_ids)
2051
+
2052
class MOGymEnv(GymEnv):
    """FARAMA MO-Gymnasium environment wrapper.

    Examples:
        >>> env = MOGymEnv(env_name="minecart-v0", frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        >>> print(env.available_envs)

    """

    git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
    libname = "mo-gymnasium"

    available_envs = MOGymWrapper.available_envs

    @property
    def lib(self) -> ModuleType:
        """The ``mo_gymnasium`` module.

        Raises:
            ImportError: if mo-gymnasium cannot be imported.
        """
        # Fix: the previous implementation only returned the module when
        # ``_has_mo`` was True; if the flag was stale (package installed after
        # module import) the property silently returned ``None``, and a broken
        # install with ``_has_mo`` True raised a raw ImportError without the
        # helpful message. Importing unconditionally covers both cases.
        try:
            import mo_gymnasium as mo_gym
        except ImportError as err:
            raise ImportError("MO-gymnasium not found, check installation") from err
        return mo_gym

    # Spec construction is delegated to GymEnv, forced onto the gymnasium backend.
    _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)
2081
+
2082
+
2083
class terminal_obs_reader(default_info_dict_reader):
    """Terminal observation reader for 'vectorized' gym environments.

    When running envs in parallel, Gym(nasium) writes the result of the true call
    to `step` in `"final_observation"` entry within the `info` dictionary.

    This breaks the natural flow and makes single-processed and multiprocessed envs
    incompatible.

    This class reads the info obs, removes the `"final_observation"` from
    the env and writes its content in the data.

    Next, a :class:`torchrl.envs.VecGymEnvTransform` transform will reorganise the
    data by caching the result of the (implicit) reset and swap the true next
    observation with the reset one. At reset time, the true reset data will be
    replaced.

    Args:
        observation_spec (Composite): The observation spec of the gym env.
        backend (str, optional): the backend of the env. One of `"sb3"` for
            stable-baselines3 or `"gym"` for gym/gymnasium.

    .. note:: In general, this class should not be handled directly. It is
        created whenever a vectorized environment is placed within a :class:`GymWrapper`.

    """

    # Info-dict key holding the terminal observation, per backend.
    # NOTE(review): gymnasium vector envs historically expose "final_observation";
    # "final_obs" here looks version-specific -- confirm against the gymnasium
    # versions this module supports.
    backend_key = {
        "sb3": "terminal_observation",
        "gym": "final_observation",
        "gymnasium": "final_obs",
    }
    # Info-dict key holding the terminal info dict, per backend.
    backend_info_key = {
        "sb3": "terminal_info",
        "gym": "final_info",
        "gymnasium": "final_info",
    }

    def __init__(self, observation_spec: Composite, backend, name="final"):
        super().__init__()
        self.name = name
        self._obs_spec = observation_spec.clone()
        self.backend = backend
        # Set to True once the info spec's ``name`` entry has been populated
        # with the observation spec (done lazily on the first __call__).
        self._final_validated = False

    @property
    def info_spec(self):
        """Spec of the entries this reader writes into the tensordict."""
        return self._info_spec

    def _read_obs(self, obs, key, tensor, index):
        """Write one sub-env's terminal observation ``obs`` into ``tensor[index]``.

        ``obs`` may be None (no terminal obs for that sub-env), an array/tensor,
        a dict keyed (at least) by ``key``, or a list/tuple to be stacked.
        """
        if obs is None:
            return
        if isinstance(obs, np.ndarray):
            # Simplest case: there is one observation,
            # presented as a np.ndarray. The key should be pixels or observation.
            # We just write that value at its location in the tensor
            tensor[index] = torch.as_tensor(obs, device=tensor.device)
        elif isinstance(obs, torch.Tensor):
            # Same as above with a ready-made tensor: just move it to the buffer
            # device. Fix: this branch used to start a fresh ``if`` chain, so
            # np.ndarray inputs fell through to the final ``else`` and raised
            # NotImplementedError after the value had already been written.
            tensor[index] = obs.to(device=tensor.device)
        elif isinstance(obs, dict):
            if key not in obs:
                raise KeyError(
                    f"The observation {key} could not be found in the final observation dict."
                )
            subobs = obs[key]
            if subobs is not None:
                # if the obs is a dict, we expect that the key points also to
                # a value in the obs. We retrieve this value and write it in the
                # tensor
                tensor[index] = torch.as_tensor(subobs, device=tensor.device)

        elif isinstance(obs, (list, tuple)):
            # tuples are stacked along the first dimension when passing gym spaces
            # to torchrl specs. As such, we can simply stack the tuple and set it
            # at the relevant index (assuming stacking can be achieved)
            tensor[index] = torch.as_tensor(obs, device=tensor.device)
        else:
            raise NotImplementedError(
                f"Observations of type {type(obs)} are not supported yet."
            )

    def __call__(self, info_dict, tensordict):
        """Extract terminal obs/info from ``info_dict`` and write them under
        ``(self.name, ...)`` in ``tensordict``."""
        # TODO: This is a tad slow, we iterate over each sub-env and call spec.zero() at each step.
        # In theory we could spare that whole thing but we need to run it once at the beginning if specs
        # of the info reader are not passed as we need to observe the data to infer the spec.
        # We should find a way to avoid this call altogether is no env is resetting.
        def replace_none(nparray):
            # Object arrays may contain None for sub-envs that did not terminate;
            # replace those entries with zero-filled structures so stacking works.
            if not isinstance(nparray, np.ndarray) or nparray.dtype != np.dtype("O"):
                return nparray
            is_none = np.array([info is None for info in nparray])
            if is_none.any():
                # Then it is a final observation and we delegate the registration to the appropriate reader
                nz = (~is_none).nonzero()[0][0]
                zero_like = tree_map(lambda x: np.zeros_like(x), nparray[nz])
                for idx in is_none.nonzero()[0]:
                    nparray[idx] = zero_like
            # tree_map with multiple trees was added in PyTorch 2.2
            if TORCH_VERSION >= version.parse("2.2"):
                return tree_map(lambda *x: np.stack(x), *nparray)
            else:
                # For older PyTorch versions, manually flatten/unflatten
                flat_lists_specs = [tree_flatten(tree) for tree in nparray]
                flat_lists = [fl for fl, _ in flat_lists_specs]
                spec = flat_lists_specs[0][1]
                stacked = [np.stack(elems) for elems in zip(*flat_lists)]
                return tree_unflatten(stacked, spec)

        info_dict = tree_map(replace_none, info_dict)
        # convert info_dict to a tensordict
        info_dict = TensorDict(info_dict)
        # get the terminal observation
        terminal_obs = info_dict.pop(self.backend_key[self.backend], None)
        # get the terminal info dict
        terminal_info = info_dict.pop(self.backend_info_key[self.backend], None)

        if terminal_info is None:
            terminal_info = {}

        # Let the base reader register the remaining (non-terminal) info entries.
        super().__call__(info_dict, tensordict)
        if not self._final_validated:
            self.info_spec[self.name] = self._obs_spec.update(self.info_spec)
            self._final_validated = True

        final_info = terminal_info.copy()
        if terminal_obs is not None:
            final_info["observation"] = terminal_obs

        for key in self.info_spec[self.name].keys():
            spec = self.info_spec[self.name, key]

            final_obs_buffer = spec.zero()
            terminal_obs = final_info.get(key, None)
            if terminal_obs is not None:
                for i, obs in enumerate(terminal_obs):
                    # writes final_obs inplace with terminal_obs content
                    self._read_obs(obs, key, final_obs_buffer, index=i)
            tensordict.set((self.name, key), final_obs_buffer)
        return tensordict

    def reset(self):
        """Reset the reader; forces spec re-validation on the next __call__."""
        super().reset()
        self._final_validated = False
2229
+
2230
def _flip_info_tuple(info: tuple[dict]) -> dict[str, tuple]:
    """Flip a tuple of per-env info dicts into a dict of per-key tuples.

    In Gym < 0.24, batched envs returned tuples of dict, and not dict of tuples.
    We patch this by flipping the tuple -> dict order. Keys missing from a given
    sub-env's dict map to ``None`` at that position.
    """
    keys = set(info[0])
    for sub_info in info[1:]:
        keys |= set(sub_info)
    return {key: tuple(sub_info.get(key) for sub_info in info) for key in keys}