torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/common.py ADDED
@@ -0,0 +1,4241 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import abc
9
+ import re
10
+ import warnings
11
+ import weakref
12
+ from collections.abc import Callable, Iterator, Sequence
13
+ from copy import deepcopy
14
+ from functools import partial, wraps
15
+ from typing import Any
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from tensordict import (
21
+ is_tensor_collection,
22
+ LazyStackedTensorDict,
23
+ TensorDictBase,
24
+ unravel_key,
25
+ )
26
+ from tensordict.base import _is_leaf_nontensor, NO_DEFAULT
27
+ from tensordict.utils import is_non_tensor, NestedKey
28
+ from torchrl._utils import (
29
+ _ends_with,
30
+ _make_ordinal_device,
31
+ _replace_last,
32
+ implement_for,
33
+ prod,
34
+ seed_generator,
35
+ )
36
+
37
+ from torchrl.data.tensor_specs import (
38
+ Categorical,
39
+ Composite,
40
+ NonTensor,
41
+ TensorSpec,
42
+ Unbounded,
43
+ )
44
+ from torchrl.data.utils import DEVICE_TYPING
45
+ from torchrl.envs.utils import (
46
+ _make_compatible_policy,
47
+ _repr_by_depth,
48
+ _StepMDP,
49
+ _terminated_or_truncated,
50
+ _update_during_reset,
51
+ check_env_specs as check_env_specs_func,
52
+ get_available_libraries,
53
+ )
54
+
55
+ LIBRARIES = get_available_libraries()
56
+
57
+
58
+ def _tensor_to_np(t):
59
+ return t.detach().cpu().numpy()
60
+
61
+
62
# Translation table from torch dtypes to the numpy / builtin types used when
# converting env data to numpy (see `_tensor_to_np` above). Dtypes missing
# from this map are presumably handled elsewhere or passed through untouched
# — TODO confirm at call sites.
dtype_map = {
    torch.float: np.float32,
    torch.double: np.float64,
    torch.bool: bool,
}
67
+
68
+
69
+ def _maybe_unlock(func):
70
+ @wraps(func)
71
+ def wrapper(self, *args, **kwargs):
72
+ is_locked = self.is_spec_locked
73
+ try:
74
+ if is_locked:
75
+ self.set_spec_lock_(False)
76
+ result = func(self, *args, **kwargs)
77
+ finally:
78
+ if is_locked:
79
+ if not hasattr(self, "_cache"):
80
+ self._cache = {}
81
+ self._cache.clear()
82
+ self.set_spec_lock_(True)
83
+ return result
84
+
85
+ return wrapper
86
+
87
+
88
+ def _cache_value(func):
89
+ """Caches the result of the decorated function in env._cache dictionary."""
90
+ func_name = func.__name__
91
+
92
+ @wraps(func)
93
+ def wrapper(self, *args, **kwargs):
94
+ if not self.is_spec_locked:
95
+ return func(self, *args, **kwargs)
96
+ result = self.__dict__.setdefault("_cache", {}).get(func_name, NO_DEFAULT)
97
+ if result is NO_DEFAULT:
98
+ result = func(self, *args, **kwargs)
99
+ self.__dict__.setdefault("_cache", {})[func_name] = result
100
+ return result
101
+
102
+ return wrapper
103
+
104
+
105
+ def _clear_cache_when_set(func):
106
+ """A decorator for EnvBase methods that should clear the caches when called."""
107
+
108
+ @wraps(func)
109
+ def wrapper(self, *args, **kwargs):
110
+ # if there's no cache we'll just recompute the value
111
+ if "_cache" not in self.__dict__:
112
+ self._cache = {}
113
+ else:
114
+ self._cache.clear()
115
+ result = func(self, *args, **kwargs)
116
+ self._cache.clear()
117
+ return result
118
+
119
+ return wrapper
120
+
121
+
122
class EnvMetaData:
    """A class for environment meta-data storage and passing in multiprocessed settings.

    Stores a fake tensordict, the env specs, batch size, a string repr,
    device, batch-lock flag and a per-key device map. The tensordict and
    specs are kept on CPU internally (so the object is cheap to send across
    processes) and cast back to ``device`` on access.
    """

    def __init__(
        self,
        *,
        tensordict: TensorDictBase,
        specs: Composite,
        batch_size: torch.Size,
        env_str: str,
        device: torch.device,
        batch_locked: bool,
        device_map: dict,
    ):
        # ``device`` must be assigned first: the ``tensordict``/``specs``
        # property setters below store CPU copies, and the getters use
        # ``self.device`` to cast them back on access.
        self.device = device
        self.tensordict = tensordict
        self.specs = specs
        self.batch_size = batch_size
        self.env_str = env_str
        self.batch_locked = batch_locked
        self.device_map = device_map
        # NOTE(review): `_has_dynamic_specs` is not among this module's visible
        # imports — presumably defined elsewhere in the file.
        self.has_dynamic_specs = _has_dynamic_specs(specs)

    @property
    def tensordict(self) -> TensorDictBase:
        """A copy of the stored (CPU) tensordict, cast to ``self.device``."""
        td = self._tensordict.copy()
        if td.device != self.device:
            if self.device is None:
                # deviceless env: strip any device info from the copy
                return td.clear_device_()
            else:
                return td.to(self.device)
        return td

    @property
    def specs(self):
        """The stored specs, cast to ``self.device`` on access."""
        return self._specs.to(self.device)

    @tensordict.setter
    def tensordict(self, value: TensorDictBase):
        # always store on CPU; the getter restores the device
        self._tensordict = value.to("cpu")

    @specs.setter
    def specs(self, value: Composite):
        # always store on CPU; the getter restores the device
        self._specs = value.to("cpu")

    @staticmethod
    def metadata_from_env(env) -> EnvMetaData:
        """Build an :class:`EnvMetaData` snapshot from a live environment."""
        tensordict = env.fake_tensordict().clone()

        # add a `_reset` entry alongside each done key, zeroed like the
        # corresponding ("next", done_key) leaf
        for done_key in env.done_keys:
            tensordict.set(
                _replace_last(done_key, "_reset"),
                torch.zeros_like(tensordict.get(("next", done_key))),
            )

        specs = env.specs.to("cpu")

        batch_size = env.batch_size
        try:
            env_str = str(env)
        except Exception:
            # some envs fail to render; fall back to the bare class name
            env_str = f"{env.__class__.__name__}()"
        device = env.device
        specs = specs.to("cpu")
        batch_locked = env.batch_locked
        # we need to save the device map, as the tensordict will be placed on cpu
        device_map = {}

        def fill_device_map(name, val, device_map=device_map):
            device_map[name] = val.device

        tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
        return EnvMetaData(
            tensordict=tensordict,
            specs=specs,
            batch_size=batch_size,
            env_str=env_str,
            device=device,
            batch_locked=batch_locked,
            device_map=device_map,
        )

    def expand(self, *size: int) -> EnvMetaData:
        """Return a new metadata object expanded to the given batch size."""
        tensordict = self.tensordict.expand(*size).clone()
        batch_size = torch.Size(list(size))
        return EnvMetaData(
            tensordict=tensordict,
            specs=self.specs.expand(*size),
            batch_size=batch_size,
            env_str=self.env_str,
            device=self.device,
            batch_locked=self.batch_locked,
            device_map=self.device_map,
        )

    def clone(self):
        """Return a deep copy of this metadata object."""
        return EnvMetaData(
            tensordict=self.tensordict.clone(),
            specs=self.specs.clone(),
            batch_size=torch.Size([*self.batch_size]),
            env_str=deepcopy(self.env_str),
            device=self.device,
            batch_locked=self.batch_locked,
            device_map=self.device_map,
        )

    def to(self, device: DEVICE_TYPING) -> EnvMetaData:
        """Return a new metadata object with everything cast to ``device``.

        The per-key device map is collapsed to the single target device.
        """
        if device is not None:
            device = _make_ordinal_device(torch.device(device))
        device_map = {key: device for key in self.device_map}
        tensordict = self.tensordict.contiguous().to(device)
        specs = self.specs.to(device)
        return EnvMetaData(
            tensordict=tensordict,
            specs=specs,
            batch_size=self.batch_size,
            env_str=self.env_str,
            device=device,
            batch_locked=self.batch_locked,
            device_map=device_map,
        )

    def __getitem__(self, item):
        """Index the metadata along the batch dimensions (like a tensordict)."""
        from tensordict.utils import _getitem_batch_size

        return EnvMetaData(
            tensordict=self.tensordict[item],
            specs=self.specs[item],
            batch_size=_getitem_batch_size(self.batch_size, item),
            env_str=self.env_str,
            device=self.device,
            batch_locked=self.batch_locked,
            device_map=self.device_map,
        )
256
+
257
+
258
class _EnvPostInit(abc.ABCMeta):
    """Metaclass for :class:`EnvBase` that finalizes instances after ``__init__``.

    It (1) consumes the ``spec_locked`` / ``auto_reset`` / ``auto_reset_replace``
    keyword arguments before they reach ``__init__``, (2) locks the specs,
    (3) touches lazy spec attributes so they are materialized, (4) optionally
    wraps the env in an auto-resetting transform, and (5) validates that the
    spec key sets do not collide.
    """

    def __call__(cls, *args, **kwargs):
        # these kwargs are handled here, at the metaclass level, and must not
        # reach the class __init__
        spec_locked = kwargs.pop("spec_locked", True)
        auto_reset = kwargs.pop("auto_reset", False)
        auto_reset_replace = kwargs.pop("auto_reset_replace", True)
        instance: EnvBase = super().__call__(*args, **kwargs)
        if "_cache" not in instance.__dict__:
            instance._cache = {}

        if spec_locked:
            instance.input_spec.lock_(recurse=True)
            instance.output_spec.lock_(recurse=True)
        instance._is_spec_locked = spec_locked

        # we create the done spec by adding a done/terminated entry if one is missing
        instance._create_done_specs()
        # we access lazy attributes to make sure they're built properly.
        # This isn't done in `__init__` because we don't know if super().__init__
        # will be called before or after the specs, batch size etc are set.
        _ = instance.done_spec
        _ = instance.reward_keys
        # _ = instance.action_keys
        _ = instance.state_spec
        if auto_reset:
            # local import to avoid a circular dependency with the transforms module
            from torchrl.envs.transforms.transforms import (
                AutoResetEnv,
                AutoResetTransform,
            )

            return AutoResetEnv(
                instance, AutoResetTransform(replace=auto_reset_replace)
            )

        done_keys = set(instance.full_done_spec.keys(True, True))
        obs_keys = set(instance.full_observation_spec.keys(True, True))
        reward_keys = set(instance.full_reward_spec.keys(True, True))
        # state_keys can match obs_keys so we don't test that
        action_keys = set(instance.full_action_spec.keys(True, True))
        state_keys = set(instance.full_state_spec.keys(True, True))
        # output-side specs (done/observation/reward) must not share keys
        total_set = set()
        for keyset in (done_keys, obs_keys, reward_keys):
            if total_set.intersection(keyset):
                raise RuntimeError(
                    f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
                )
            total_set = total_set.union(keyset)
        # input-side specs (state/action) must not share keys either
        total_set = set()
        for keyset in (state_keys, action_keys):
            if total_set.intersection(keyset):
                raise RuntimeError(
                    f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
                )
            total_set = total_set.union(keyset)
        return instance
312
+
313
+
314
+ class EnvBase(nn.Module, metaclass=_EnvPostInit):
315
+ """Abstract environment parent class.
316
+
317
+ Keyword Args:
318
+ device (torch.device): The device of the environment. Deviceless environments
319
+ are allowed (device=None). If not ``None``, all specs will be cast
320
+ on that device and it is expected that all inputs and outputs will
321
+ live on that device.
322
+ Defaults to ``None``.
323
+ batch_size (torch.Size or equivalent, optional): batch-size of the environment.
324
+ Corresponds to the leading dimension of all the input and output
325
+ tensordicts the environment reads and writes. Defaults to an empty batch-size.
326
+ run_type_checks (bool, optional): If ``True``, type-checks will occur
327
+ at every reset and every step. Defaults to ``False``.
328
+ allow_done_after_reset (bool, optional): if ``True``, an environment can
329
+ be done after a call to :meth:`reset` is made. Defaults to ``False``.
330
+ spec_locked (bool, optional): if ``True``, the specs are locked and can only be
331
+ modified if :meth:`~torchrl.envs.EnvBase.set_spec_lock_` is called.
332
+
333
+ .. note:: The locking is achieved by the `EnvBase` metaclass. It does not appear in the
334
+ `__init__` method and is included in the keyword arguments strictly for type-hinting purpose.
335
+
336
+ .. seealso:: :ref:`Locking environment specs <Environment-lock>`.
337
+
338
+ Defaults to ``True``.
339
+ auto_reset (bool, optional): if ``True``, the env is assumed to reset automatically
340
+ when done. Defaults to ``False``.
341
+
342
+ .. note:: The auto-resetting is achieved by the `EnvBase` metaclass. It does not appear in the
343
+ `__init__` method and is included in the keyword arguments strictly for type-hinting purpose.
344
+
345
+ .. seealso:: The :ref:`auto-resetting environments API <autoresetting_envs>` section in the API
346
+ documentation.
347
+
348
+ Attributes:
349
+ done_spec (Composite): equivalent to ``full_done_spec`` as all
350
+ ``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry
351
+ action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf
352
+ action if only one action tensor is to be expected. Otherwise links to
353
+ ``full_action_spec``.
354
+ observation_spec (Composite): equivalent to ``full_observation_spec``.
355
+ reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf
356
+ reward if only one reward tensor is to be expected. Otherwise links to
357
+ ``full_reward_spec``.
358
+ state_spec (Composite): equivalent to ``full_state_spec``.
359
+ full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()``
360
+ returns a tensordict containing only the leaves encoding the done status of the
361
+ environment.
362
+ full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()``
363
+ returns a tensordict containing only the leaves encoding the action of the
364
+ environment.
365
+ full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()``
366
+ returns a tensordict containing only the leaves encoding the observation of the
367
+ environment.
368
+ full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()``
369
+ returns a tensordict containing only the leaves encoding the reward of the
370
+ environment.
371
+ full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()``
372
+ returns a tensordict containing only the leaves encoding the inputs (actions
373
+ excluded) of the environment.
374
+ batch_size (torch.Size): The batch-size of the environment.
375
+ device (torch.device): the device where the input/outputs of the environment
376
+ are to be expected. Can be ``None``.
377
+ is_spec_locked (bool): returns ``True`` if the specs are locked. See the :attr:`spec_locked`
378
+ argument above.
379
+
380
+ Methods:
381
+ step (TensorDictBase -> TensorDictBase): step in the environment
382
+ reset (TensorDictBase, optional -> TensorDictBase): reset the environment
383
+ set_seed (int -> int): sets the seed of the environment
384
+ rand_step (TensorDictBase, optional -> TensorDictBase): random step given the action spec
385
+ rollout (Callable, ... -> TensorDictBase): executes a rollout in the environment with the given policy (or random
386
+ steps if no policy is provided)
387
+
388
+ Examples:
389
+ >>> from torchrl.envs import EnvBase
390
+ >>> class CounterEnv(EnvBase):
391
+ ... def __init__(self, batch_size=(), device=None, **kwargs):
392
+ ... self.observation_spec = Composite(
393
+ ... count=Unbounded(batch_size, device=device, dtype=torch.int64))
394
+ ... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8)
395
+ ... # done spec and reward spec are set automatically
396
+ ... def _step(self, tensordict):
397
+ ...
398
+ >>> from torchrl.envs.libs.gym import GymEnv
399
+ >>> env = GymEnv("Pendulum-v1")
400
+ >>> env.batch_size # how many envs are run at once
401
+ torch.Size([])
402
+ >>> env.input_spec
403
+ Composite(
404
+ full_state_spec: None,
405
+ full_action_spec: Composite(
406
+ action: BoundedContinuous(
407
+ shape=torch.Size([1]),
408
+ space=ContinuousBox(
409
+ low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
410
+ high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
411
+ device=cpu,
412
+ dtype=torch.float32,
413
+ domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
414
+ >>> env.action_spec
415
+ BoundedContinuous(
416
+ shape=torch.Size([1]),
417
+ space=ContinuousBox(
418
+ low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
419
+ high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
420
+ device=cpu,
421
+ dtype=torch.float32,
422
+ domain=continuous)
423
+ >>> env.observation_spec
424
+ Composite(
425
+ observation: BoundedContinuous(
426
+ shape=torch.Size([3]),
427
+ space=ContinuousBox(
428
+ low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
429
+ high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
430
+ device=cpu,
431
+ dtype=torch.float32,
432
+ domain=continuous), device=cpu, shape=torch.Size([]))
433
+ >>> env.reward_spec
434
+ UnboundedContinuous(
435
+ shape=torch.Size([1]),
436
+ space=None,
437
+ device=cpu,
438
+ dtype=torch.float32,
439
+ domain=continuous)
440
+ >>> env.done_spec
441
+ Categorical(
442
+ shape=torch.Size([1]),
443
+ space=DiscreteBox(n=2),
444
+ device=cpu,
445
+ dtype=torch.bool,
446
+ domain=discrete)
447
+ >>> # the output_spec contains all the expected outputs
448
+ >>> env.output_spec
449
+ Composite(
450
+ full_reward_spec: Composite(
451
+ reward: UnboundedContinuous(
452
+ shape=torch.Size([1]),
453
+ space=None,
454
+ device=cpu,
455
+ dtype=torch.float32,
456
+ domain=continuous), device=cpu, shape=torch.Size([])),
457
+ full_observation_spec: Composite(
458
+ observation: BoundedContinuous(
459
+ shape=torch.Size([3]),
460
+ space=ContinuousBox(
461
+ low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
462
+ high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
463
+ device=cpu,
464
+ dtype=torch.float32,
465
+ domain=continuous), device=cpu, shape=torch.Size([])),
466
+ full_done_spec: Composite(
467
+ done: Categorical(
468
+ shape=torch.Size([1]),
469
+ space=DiscreteBox(n=2),
470
+ device=cpu,
471
+ dtype=torch.bool,
472
+ domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
473
+
474
+ .. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
475
+ """
476
+
477
    # Lazily-initialized batch size of the env; ``None`` until first set/accessed.
    _batch_size: torch.Size | None
    # Device of the env; ``None`` for deviceless environments.
    _device: torch.device | None
    # Whether the input/output specs are locked (see ``set_spec_lock_``).
    _is_spec_locked: bool = False
480
+
481
    def __init__(
        self,
        *,
        device: DEVICE_TYPING | None = None,
        batch_size: tuple | torch.Size | None = None,
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
        spec_locked: bool = True,
        auto_reset: bool = False,
    ):
        """Initialize the environment.

        See the class docstring for the meaning of every keyword argument.
        ``spec_locked`` and ``auto_reset`` are consumed by the metaclass
        (``_EnvPostInit``), not here; they appear in the signature strictly for
        type-hinting purposes.
        """
        # The cache may already exist if a subclass touched cached properties
        # before calling super().__init__().
        if "_cache" not in self.__dict__:
            self._cache = {}
        super().__init__()

        # setdefault: do not clobber values possibly installed earlier
        # (e.g. by __new__ or a subclass constructor).
        self.__dict__.setdefault("_batch_size", None)
        self.__dict__.setdefault("_device", None)

        if batch_size is not None:
            # we want an error to be raised if we pass batch_size but
            # it's already been set
            batch_size = self.batch_size = torch.Size(batch_size)
        else:
            batch_size = torch.Size(())

        if device is not None:
            # Normalize e.g. "cuda" -> "cuda:0" before storing.
            device = self.__dict__["_device"] = _make_ordinal_device(
                torch.device(device)
            )

        # Create the top-level output/input specs, or move pre-existing ones to
        # the requested device.
        output_spec = self.__dict__.get("_output_spec")
        if output_spec is None:
            output_spec = self.__dict__["_output_spec"] = Composite(
                shape=batch_size, device=device
            )
        elif self._output_spec.device != device and device is not None:
            # NOTE(review): after ``.to()``, the local ``output_spec`` still
            # points at the pre-move spec, so the mutations below target the
            # stale object while __dict__ holds the moved copy — looks
            # suspicious; verify upstream.
            self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(
                self.device
            )
        input_spec = self.__dict__.get("_input_spec")
        if input_spec is None:
            input_spec = self.__dict__["_input_spec"] = Composite(
                shape=batch_size, device=device
            )
        elif self._input_spec.device != device and device is not None:
            self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device)

        # Guarantee that every sub-spec container exists, even if empty.
        output_spec.unlock_(recurse=True)
        input_spec.unlock_(recurse=True)
        if "full_observation_spec" not in output_spec:
            output_spec["full_observation_spec"] = Composite(batch_size=batch_size)
        if "full_done_spec" not in output_spec:
            output_spec["full_done_spec"] = Composite(batch_size=batch_size)
        if "full_reward_spec" not in output_spec:
            output_spec["full_reward_spec"] = Composite(batch_size=batch_size)
        if "full_state_spec" not in input_spec:
            input_spec["full_state_spec"] = Composite(batch_size=batch_size)
        if "full_action_spec" not in input_spec:
            input_spec["full_action_spec"] = Composite(batch_size=batch_size)

        # Subclasses (e.g. batched envs) may manage is_closed themselves.
        if "is_closed" not in self.__dir__():
            self.is_closed = True
        self._run_type_checks = run_type_checks
        self._allow_done_after_reset = allow_done_after_reset
545
    # Weak reference to the collector currently using this env, populated by
    # ``register_collector``; ``None`` when no collector is registered.
    # NOTE(review): annotated with LLMCollector although ``register_collector``
    # accepts any BaseCollector — the annotation looks over-specific; confirm.
    _collector: weakref.ReferenceType[
        LLMCollector  # noqa: F821 # type: ignore
    ] | None = None
548
+
549
+ def register_collector(self, collector: BaseCollector): # noqa: F821 # type: ignore
550
+ """Registers a collector with the environment.
551
+
552
+ Args:
553
+ collector (BaseCollector): The collector to register.
554
+ """
555
+ self._collector = weakref.ref(collector)
556
+
557
+ @property
558
+ def collector(self) -> BaseCollector | None: # noqa: F821 # type: ignore
559
+ """Returns the collector associated with the container, if it exists."""
560
+ return self._collector() if self._collector is not None else None
561
+
562
+ def set_spec_lock_(self, mode: bool = True) -> EnvBase:
563
+ """Locks or unlocks the environment's specs.
564
+
565
+ Args:
566
+ mode (bool): Whether to lock (`True`) or unlock (`False`) the specs. Defaults to `True`.
567
+
568
+ Returns:
569
+ EnvBase: The environment instance itself.
570
+
571
+ .. seealso:: :ref:`Locking environment specs <Environment-lock>`.
572
+
573
+ """
574
+ output_spec = self.__dict__.get("_output_spec")
575
+ input_spec = self.__dict__.get("_input_spec")
576
+ if mode:
577
+ if output_spec is not None:
578
+ output_spec.lock_(recurse=True)
579
+ if input_spec is not None:
580
+ input_spec.lock_(recurse=True)
581
+ else:
582
+ self._cache.clear()
583
+ if output_spec is not None:
584
+ output_spec.unlock_(recurse=True)
585
+ if input_spec is not None:
586
+ input_spec.unlock_(recurse=True)
587
+ self.__dict__["_is_spec_locked"] = mode
588
+ return self
589
+
590
+ @property
591
+ def is_spec_locked(self):
592
+ """Gets whether the environment's specs are locked.
593
+
594
+ This property can be modified directly.
595
+
596
+ Returns:
597
+ bool: True if the specs are locked, False otherwise.
598
+
599
+ .. seealso:: :ref:`Locking environment specs <Environment-lock>`.
600
+
601
+ """
602
+ return self.__dict__.get("_is_spec_locked", False)
603
+
604
+ @is_spec_locked.setter
605
+ def is_spec_locked(self, value: bool):
606
+ self.set_spec_lock_(value)
607
+
608
+ def auto_specs_(
609
+ self,
610
+ policy: Callable[[TensorDictBase], TensorDictBase],
611
+ *,
612
+ tensordict: TensorDictBase | None = None,
613
+ action_key: NestedKey | list[NestedKey] = "action",
614
+ done_key: NestedKey | list[NestedKey] | None = None,
615
+ observation_key: NestedKey | list[NestedKey] = "observation",
616
+ reward_key: NestedKey | list[NestedKey] = "reward",
617
+ ):
618
+ """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
619
+
620
+ This method performs a rollout using the provided policy to infer the input and output specifications of the environment.
621
+ It updates the environment's specs for actions, observations, rewards, and done signals based on the data collected
622
+ during the rollout.
623
+
624
+ Args:
625
+ policy (Callable[[TensorDictBase], TensorDictBase]):
626
+ A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output.
627
+ This policy is used to perform the rollout and determine the specs.
628
+
629
+ Keyword Args:
630
+ tensordict (TensorDictBase, optional):
631
+ An optional `TensorDictBase` instance to be used as the initial state for the rollout.
632
+ If not provided, the environment's `reset` method will be called to obtain the initial state.
633
+ action_key (NestedKey or List[NestedKey], optional):
634
+ The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action".
635
+ done_key (NestedKey or List[NestedKey], optional):
636
+ The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will
637
+ attempt to use ["done", "terminated", "truncated"] as potential keys.
638
+ observation_key (NestedKey or List[NestedKey], optional):
639
+ The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation".
640
+ reward_key (NestedKey or List[NestedKey], optional):
641
+ The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward".
642
+
643
+ Returns:
644
+ EnvBase: The environment instance with updated specs.
645
+
646
+ Raises:
647
+ RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys.
648
+ """
649
+ if self.batch_locked or tensordict is None:
650
+ batch_size = self.batch_size
651
+ else:
652
+ batch_size = tensordict.batch_size
653
+ if tensordict is None:
654
+ tensordict = self.reset()
655
+
656
+ # Input specs
657
+ tensordict.update(policy(tensordict))
658
+ step_0 = self.step(tensordict.copy())
659
+ tensordict2 = step_0.get("next").copy()
660
+ step_1 = self.step(policy(tensordict2).copy())
661
+ nexts_0: TensorDictBase = step_0.pop("next")
662
+ nexts_1: TensorDictBase = step_1.pop("next")
663
+
664
+ input_spec_stack = {}
665
+ tensordict.apply(
666
+ partial(_tensor_to_spec, stack=input_spec_stack),
667
+ tensordict2,
668
+ named=True,
669
+ nested_keys=True,
670
+ is_leaf=_is_leaf_nontensor,
671
+ )
672
+ input_spec = Composite(input_spec_stack, batch_size=batch_size)
673
+ if not self.batch_locked and batch_size != self.batch_size:
674
+ while input_spec.shape:
675
+ input_spec = input_spec[0]
676
+ if isinstance(action_key, NestedKey):
677
+ action_key = [action_key]
678
+ full_action_spec = input_spec.separates(*action_key, default=None)
679
+
680
+ # Output specs
681
+
682
+ output_spec_stack = {}
683
+ nexts_0.apply(
684
+ partial(_tensor_to_spec, stack=output_spec_stack),
685
+ nexts_1,
686
+ named=True,
687
+ nested_keys=True,
688
+ is_leaf=_is_leaf_nontensor,
689
+ )
690
+
691
+ output_spec = Composite(output_spec_stack, batch_size=batch_size)
692
+ if not self.batch_locked and batch_size != self.batch_size:
693
+ while output_spec.shape:
694
+ output_spec = output_spec[0]
695
+
696
+ if done_key is None:
697
+ done_key = ["done", "terminated", "truncated"]
698
+ full_done_spec = output_spec.separates(*done_key, default=None)
699
+ if full_done_spec is not None:
700
+ self.full_done_spec = full_done_spec
701
+
702
+ if isinstance(reward_key, NestedKey):
703
+ reward_key = [reward_key]
704
+ full_reward_spec = output_spec.separates(*reward_key, default=None)
705
+
706
+ if isinstance(observation_key, NestedKey):
707
+ observation_key = [observation_key]
708
+ full_observation_spec = output_spec.separates(*observation_key, default=None)
709
+ if not output_spec.is_empty(recurse=True):
710
+ raise RuntimeError(
711
+ f"Keys {list(output_spec.keys(True, True))} are unaccounted for. "
712
+ f"Make sure you have passed all the leaf names to the auto_specs_ method."
713
+ )
714
+
715
+ if full_action_spec is not None:
716
+ self.full_action_spec = full_action_spec
717
+ if full_done_spec is not None:
718
+ self.full_done_spec = full_done_spec
719
+ if full_observation_spec is not None:
720
+ self.full_observation_spec = full_observation_spec
721
+ if full_reward_spec is not None:
722
+ self.full_reward_spec = full_reward_spec
723
+ full_state_spec = input_spec
724
+ self.full_state_spec = full_state_spec
725
+
726
+ return self
727
+
728
    def check_env_specs(self, *args, **kwargs):
        # Thin wrapper around the module-level ``check_env_specs`` helper.
        # Default ``return_contiguous`` to False for dynamic-spec envs
        # (presumably because their rollouts cannot be stacked contiguously —
        # TODO confirm); an explicit caller-provided value always wins.
        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
        return check_env_specs_func(self, *args, **kwargs)

    # Re-use the helper's docstring so help(env.check_env_specs) shows the
    # complete documentation.
    check_env_specs.__doc__ = check_env_specs_func.__doc__
733
+
734
    def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
        """The cardinality of the action space.

        By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`.

        This method is useful when the action spec is variable:

        - The number of actions can be undefined, e.g., ``Categorical(n=-1)``;
        - The action cardinality may depend on the action mask;
        - The shape can be dynamic, as in ``Unbounded(shape=(-1))``.

        In these cases, the :meth:`cardinality` should be overwritten.

        Args:
            tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality.

        Returns:
            int: the number of possible actions.

        """
        # ``tensordict`` is accepted so overrides can use data (e.g. action
        # masks), but this default implementation ignores it.
        return self.full_action_spec.cardinality()
752
+
753
+ def configure_parallel(
754
+ self,
755
+ *,
756
+ use_buffers: bool | None = None,
757
+ shared_memory: bool | None = None,
758
+ memmap: bool | None = None,
759
+ mp_start_method: str | None = None,
760
+ num_threads: int | None = None,
761
+ num_sub_threads: int | None = None,
762
+ non_blocking: bool | None = None,
763
+ daemon: bool | None = None,
764
+ ) -> EnvBase:
765
+ """Configure parallel execution parameters.
766
+
767
+ This method allows configuring parameters for parallel environment
768
+ execution before the environment is started. It is only effective
769
+ on :class:`~torchrl.envs.BatchedEnvBase` and its subclasses.
770
+
771
+ Args:
772
+ use_buffers (bool, optional): whether communication between workers should
773
+ occur via circular preallocated memory buffers.
774
+ shared_memory (bool, optional): whether the returned tensordict will be
775
+ placed in shared memory.
776
+ memmap (bool, optional): whether the returned tensordict will be placed
777
+ in memory map.
778
+ mp_start_method (str, optional): the multiprocessing start method.
779
+ num_threads (int, optional): number of threads for this process.
780
+ num_sub_threads (int, optional): number of threads of the subprocesses.
781
+ non_blocking (bool, optional): if ``True``, device moves will be done using
782
+ the ``non_blocking=True`` option.
783
+ daemon (bool, optional): whether the processes should be daemonized.
784
+
785
+ Returns:
786
+ self: Returns self for method chaining.
787
+
788
+ Raises:
789
+ NotImplementedError: If called on an environment that does not support
790
+ parallel configuration.
791
+ RuntimeError: If called after the environment has already started.
792
+
793
+ Example:
794
+ >>> env = DMControlEnv("cheetah", "run", num_envs=4)
795
+ >>> env.configure_parallel(use_buffers=True, num_threads=2)
796
+ >>> env.reset() # Environment starts here, configure_parallel no longer effective
797
+
798
+ """
799
+ raise NotImplementedError(
800
+ f"{type(self).__name__} does not support configure_parallel. "
801
+ "This method is only available on BatchedEnvBase and its subclasses."
802
+ )
803
+
804
    @classmethod
    def make_parallel(
        cls,
        create_env_fn,
        *,
        num_envs: int = 1,
        create_env_kwargs: dict | Sequence[dict] | None = None,
        pin_memory: bool = False,
        share_individual_td: bool | None = None,
        shared_memory: bool = True,
        memmap: bool = False,
        policy_proof: Callable | None = None,
        device: DEVICE_TYPING | None = None,
        allow_step_when_done: bool = False,
        num_threads: int | None = None,
        num_sub_threads: int = 1,
        serial_for_single: bool = False,
        non_blocking: bool = False,
        mp_start_method: str | None = None,
        use_buffers: bool | None = None,
        consolidate: bool = True,
        daemon: bool = False,
        **parallel_kwargs,
    ) -> EnvBase:
        """Factory method to create a ParallelEnv from an environment creator.

        This method provides a convenient way to create parallel environments
        with the same signature as :class:`~torchrl.envs.ParallelEnv`.

        Args:
            create_env_fn (callable): A callable that creates an environment instance.
            num_envs (int, optional): Number of parallel environments. Defaults to 1.
            create_env_kwargs (dict or list of dicts, optional): kwargs to be used
                with the environments being created.
            pin_memory (bool, optional): Whether to pin memory. Defaults to False.
            share_individual_td (bool, optional): if ``True``, a different tensordict
                is created for every process/worker and a lazy stack is returned.
            shared_memory (bool, optional): whether the returned tensordict will be
                placed in shared memory. Defaults to True.
            memmap (bool, optional): whether the returned tensordict will be placed
                in memory map. Defaults to False.
            policy_proof (callable, optional): if provided, it'll be used to get
                the list of tensors to return through step() and reset() methods.
            device (str, int, torch.device, optional): The device of the batched
                environment.
            allow_step_when_done (bool, optional): Allow stepping when done.
                Defaults to False.
            num_threads (int, optional): number of threads for this process.
            num_sub_threads (int, optional): number of threads of the subprocesses.
                Defaults to 1.
            serial_for_single (bool, optional): if ``True``, creating a parallel
                environment with a single worker will return a SerialEnv instead.
                Defaults to False.
            non_blocking (bool, optional): if ``True``, device moves will be done
                using the ``non_blocking=True`` option. Defaults to False.
            mp_start_method (str, optional): the multiprocessing start method.
            use_buffers (bool, optional): whether communication between workers
                should occur via circular preallocated memory buffers.
            consolidate (bool, optional): Whether to consolidate tensordicts.
                Defaults to True.
            daemon (bool, optional): whether the processes should be daemonized.
                Defaults to False.
            **parallel_kwargs: Additional keyword arguments passed to ParallelEnv.

        Returns:
            EnvBase: A ParallelEnv (or SerialEnv if serial_for_single=True and num_envs=1).

        """
        # Local import: presumably avoids a circular import between this module
        # and torchrl.envs — TODO confirm.
        from torchrl.envs import ParallelEnv

        # Pure keyword forwarding; a key duplicated in ``parallel_kwargs``
        # raises TypeError, as with any double keyword argument.
        return ParallelEnv(
            num_workers=num_envs,
            create_env_fn=create_env_fn,
            create_env_kwargs=create_env_kwargs,
            pin_memory=pin_memory,
            share_individual_td=share_individual_td,
            shared_memory=shared_memory,
            memmap=memmap,
            policy_proof=policy_proof,
            device=device,
            allow_step_when_done=allow_step_when_done,
            num_threads=num_threads,
            num_sub_threads=num_sub_threads,
            serial_for_single=serial_for_single,
            non_blocking=non_blocking,
            mp_start_method=mp_start_method,
            use_buffers=use_buffers,
            consolidate=consolidate,
            daemon=daemon,
            **parallel_kwargs,
        )
895
+
896
+ @classmethod
897
+ def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
898
+ # inplace update will write tensors in-place on the provided tensordict.
899
+ # This is risky, especially if gradients need to be passed (in-place copy
900
+ # for tensors that are part of computational graphs will result in an error).
901
+ # It can also lead to inconsistencies when calling rollout.
902
+ cls._inplace_update = _inplace_update
903
+ cls._batch_locked = _batch_locked
904
+ cls._device = None
905
+ # cached in_keys to be excluded from update when calling step
906
+ cls._cache_in_keys = None
907
+
908
+ # We may assign _input_spec to the cls, but it must be assigned to the instance
909
+ # we pull it off, and place it back where it belongs
910
+ _input_spec = None
911
+ if hasattr(cls, "_input_spec"):
912
+ _input_spec = cls._input_spec.clone()
913
+ delattr(cls, "_input_spec")
914
+ _output_spec = None
915
+ if hasattr(cls, "_output_spec"):
916
+ _output_spec = cls._output_spec.clone()
917
+ delattr(cls, "_output_spec")
918
+ env = super().__new__(cls)
919
+ if _input_spec is not None:
920
+ env.__dict__["_input_spec"] = _input_spec
921
+ if _output_spec is not None:
922
+ env.__dict__["_output_spec"] = _output_spec
923
+ return env
924
+
925
+ return super().__new__(cls)
926
+
927
+ def __setattr__(self, key, value):
928
+ if key in (
929
+ "_input_spec",
930
+ "_observation_spec",
931
+ "_action_spec",
932
+ "_reward_spec",
933
+ "_output_spec",
934
+ "_state_spec",
935
+ "_done_spec",
936
+ ):
937
+ raise AttributeError(
938
+ "To set an environment spec, please use `env.observation_spec = obs_spec` (without the leading"
939
+ " underscore)."
940
+ )
941
+ super().__setattr__(key, value)
942
+
943
+ @property
944
+ def batch_locked(self) -> bool:
945
+ """Whether the environment can be used with a batch size different from the one it was initialized with or not.
946
+
947
+ If True, the env needs to be used with a tensordict having the same batch size as the env.
948
+ batch_locked is an immutable property.
949
+ """
950
+ return self._batch_locked
951
+
952
+ @batch_locked.setter
953
+ def batch_locked(self, value: bool) -> None:
954
+ raise RuntimeError("batch_locked is a read-only property")
955
+
956
    @property
    def run_type_checks(self) -> bool:
        """Whether type-checks run at every reset and step (see ``__init__``)."""
        return self._run_type_checks

    @run_type_checks.setter
    def run_type_checks(self, run_type_checks: bool) -> None:
        self._run_type_checks = run_type_checks
963
+
964
+ @property
965
+ def batch_size(self) -> torch.Size:
966
+ """Number of envs batched in this environment instance organised in a `torch.Size()` object.
967
+
968
+ Environment may be similar or different but it is assumed that they have little if
969
+ not no interactions between them (e.g., multi-task or batched execution
970
+ in parallel).
971
+
972
+ """
973
+ _batch_size = self.__dict__.get("_batch_size")
974
+ if _batch_size is None:
975
+ _batch_size = self._batch_size = torch.Size([])
976
+ return _batch_size
977
+
978
    @batch_size.setter
    @_maybe_unlock
    def batch_size(self, value: torch.Size) -> None:
        # Keep the specs' leading dimensions consistent with the new batch
        # size. ``_maybe_unlock`` presumably lifts the spec lock for the
        # duration of the update and restores it afterwards — defined
        # elsewhere in this module; confirm.
        self._batch_size = torch.Size(value)
        if (
            hasattr(self, "output_spec")
            and self.output_spec.shape[: len(value)] != value
        ):
            self.output_spec.shape = value
        if hasattr(self, "input_spec") and self.input_spec.shape[: len(value)] != value:
            self.input_spec.shape = value
989
+
990
    @property
    def shape(self):
        """Equivalent to :attr:`~.batch_size`."""
        # Tensor-like alias so envs expose the same attribute as tensordicts.
        return self.batch_size
994
+
995
+ @property
996
+ def device(self) -> torch.device:
997
+ device = self.__dict__.get("_device")
998
+ return device
999
+
1000
+ @device.setter
1001
+ def device(self, value: torch.device) -> None:
1002
+ device = self.__dict__.get("_device")
1003
+ if device is None:
1004
+ self.__dict__["_device"] = value
1005
+ return
1006
+ raise RuntimeError("device cannot be set. Call env.to(device) instead.")
1007
+
1008
    def ndimension(self):
        """Return the number of dimensions of the env's batch-size."""
        return len(self.batch_size)

    @property
    def ndim(self):
        """Equivalent to :meth:`ndimension`, mirroring the tensor API."""
        return self.ndimension()
1014
+
1015
    def append_transform(
        self,
        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
    ) -> torchrl.envs.TransformedEnv:  # noqa
        """Returns a transformed environment where the callable/transform passed is applied.

        Args:
            transform (Transform or Callable[[TensorDictBase], TensorDictBase]): the transform to apply
                to the environment.

        Returns:
            TransformedEnv: a new environment wrapping ``self`` with the transform;
                ``self`` is left unmodified.

        Examples:
            >>> from torchrl.envs import GymEnv
            >>> import torch
            >>> env = GymEnv("CartPole-v1")
            >>> loc = 0.5
            >>> scale = 1.0
            >>> transform = lambda data: data.set("observation", (data.get("observation") - loc)/scale)
            >>> env = env.append_transform(transform=transform)
            >>> print(env)
            TransformedEnv(
                env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
                transform=_CallableTransform(keys=[]))

        """
        # Local import: presumably avoids a circular import at module load
        # time — TODO confirm.
        from torchrl.envs.transforms.transforms import TransformedEnv

        return TransformedEnv(self, transform)
1042
+
1043
+ # Parent specs: input and output spec.
1044
    @property
    def input_spec(self) -> TensorSpec:
        """Input spec.

        The composite spec containing all specs for data input to the environments.

        It contains:

        - "full_action_spec": the spec of the input actions
        - "full_state_spec": the spec of all other environment inputs

        This attribute is locked and should be read-only.
        Instead, to set the specs contained in it, use the respective properties.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.input_spec
            Composite(
                full_state_spec: None,
                full_action_spec: Composite(
                    action: BoundedContinuous(
                        shape=torch.Size([1]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))


        """
        # Lazily build the top-level input spec on first access.
        input_spec = self.__dict__.get("_input_spec")
        if input_spec is None:
            # Temporarily lift the spec lock so the fresh spec can be created
            # and stored, then restore the previous lock state.
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            input_spec = Composite(
                full_state_spec=None,
                shape=self.batch_size,
                device=self.device,
            )
            self.__dict__["_input_spec"] = input_spec
            if is_locked:
                self.set_spec_lock_(True)
        return input_spec

    @input_spec.setter
    def input_spec(self, value: TensorSpec) -> None:
        raise RuntimeError("input_spec is protected.")
1094
+
1095
    @property
    def output_spec(self) -> TensorSpec:
        """Output spec.

        The composite spec containing all specs for data output from the environments.

        It contains:

        - "full_reward_spec": the spec of reward
        - "full_done_spec": the spec of done
        - "full_observation_spec": the spec of all other environment outputs

        This attribute is locked and should be read-only.
        Instead, to set the specs contained in it, use the respective properties.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.output_spec
            Composite(
                full_reward_spec: Composite(
                    reward: UnboundedContinuous(
                        shape=torch.Size([1]),
                        space=None,
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])),
                full_observation_spec: Composite(
                    observation: BoundedContinuous(
                        shape=torch.Size([3]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])),
                full_done_spec: Composite(
                    done: Categorical(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        device=cpu,
                        dtype=torch.bool,
                        domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))


        """
        # Lazily build the top-level output spec on first access.
        output_spec = self.__dict__.get("_output_spec")
        if output_spec is None:
            # Temporarily lift the spec lock so the fresh spec can be created
            # and stored, then restore the previous lock state.
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            output_spec = Composite(
                shape=self.batch_size,
                device=self.device,
            )
            self.__dict__["_output_spec"] = output_spec
            if is_locked:
                self.set_spec_lock_(True)
        return output_spec

    @output_spec.setter
    def output_spec(self, value: TensorSpec) -> None:
        raise RuntimeError("output_spec is protected.")
1158
+
1159
    @property
    @_cache_value
    def action_keys(self) -> list[NestedKey]:
        """The action keys of an environment.

        By default, there will only be one key named "action".

        Keys are sorted by depth in the data tree.
        """
        # Cached by ``_cache_value``; the cache is dropped when the specs are
        # unlocked (see ``set_spec_lock_``).
        keys = self.full_action_spec.keys(True, True)
        keys = sorted(keys, key=_repr_by_depth)
        return keys
1171
+
1172
    @property
    @_cache_value
    def state_keys(self) -> list[NestedKey]:
        """The state keys of an environment.

        By default, there will only be one key named "state".

        Keys are sorted by depth in the data tree.
        """
        # NOTE(review): this property keeps a manual cache in
        # self.__dict__["_state_keys"] in addition to the @_cache_value
        # decorator — presumably legacy double caching; confirm which layer
        # other code invalidates before removing either.
        state_keys = self.__dict__.get("_state_keys")
        if state_keys is not None:
            return state_keys
        keys = self.input_spec["full_state_spec"].keys(True, True)
        keys = sorted(keys, key=_repr_by_depth)
        self.__dict__["_state_keys"] = keys
        return keys
1188
+
1189
+ @property
1190
+ def action_key(self) -> NestedKey:
1191
+ """The action key of an environment.
1192
+
1193
+ By default, this will be "action".
1194
+
1195
+ If there is more than one action key in the environment, this function will raise an exception.
1196
+ """
1197
+ if len(self.action_keys) > 1:
1198
+ raise KeyError(
1199
+ "action_key requested but more than one key present in the environment"
1200
+ )
1201
+ return self.action_keys[0]
1202
+
1203
+ # Action spec: action specs belong to input_spec
1204
+ @property
1205
+ def action_spec(self) -> TensorSpec:
1206
+ """The ``action`` spec.
1207
+
1208
+ The ``action_spec`` is always stored as a composite spec.
1209
+
1210
+ If the action spec is provided as a simple spec, this will be returned.
1211
+
1212
+ >>> env.action_spec = Unbounded(1)
1213
+ >>> env.action_spec
1214
+ UnboundedContinuous(
1215
+ shape=torch.Size([1]),
1216
+ space=ContinuousBox(
1217
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1218
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1219
+ device=cpu,
1220
+ dtype=torch.float32,
1221
+ domain=continuous)
1222
+
1223
+ If the action spec is provided as a composite spec and contains only one leaf,
1224
+ this function will return just the leaf.
1225
+
1226
+ >>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}})
1227
+ >>> env.action_spec
1228
+ UnboundedContinuous(
1229
+ shape=torch.Size([1]),
1230
+ space=ContinuousBox(
1231
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1232
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1233
+ device=cpu,
1234
+ dtype=torch.float32,
1235
+ domain=continuous)
1236
+
1237
+ If the action spec is provided as a composite spec and has more than one leaf,
1238
+ this function will return the whole spec.
1239
+
1240
+ >>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}})
1241
+ >>> env.action_spec
1242
+ Composite(
1243
+ nested: Composite(
1244
+ action: UnboundedContinuous(
1245
+ shape=torch.Size([1]),
1246
+ space=ContinuousBox(
1247
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1248
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1249
+ device=cpu,
1250
+ dtype=torch.float32,
1251
+ domain=continuous),
1252
+ another_action: Categorical(
1253
+ shape=torch.Size([]),
1254
+ space=DiscreteBox(n=1),
1255
+ device=cpu,
1256
+ dtype=torch.int64,
1257
+ domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
1258
+
1259
+ To retrieve the full spec passed, use:
1260
+
1261
+ >>> env.input_spec["full_action_spec"]
1262
+
1263
+ This property is mutable.
1264
+
1265
+ Examples:
1266
+ >>> from torchrl.envs.libs.gym import GymEnv
1267
+ >>> env = GymEnv("Pendulum-v1")
1268
+ >>> env.action_spec
1269
+ BoundedContinuous(
1270
+ shape=torch.Size([1]),
1271
+ space=ContinuousBox(
1272
+ low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
1273
+ high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
1274
+ device=cpu,
1275
+ dtype=torch.float32,
1276
+ domain=continuous)
1277
+ """
1278
+ try:
1279
+ action_spec = self.input_spec["full_action_spec"]
1280
+ except (KeyError, AttributeError):
1281
+ raise KeyError("Failed to find the action_spec.")
1282
+
1283
+ if len(self.action_keys) > 1:
1284
+ return action_spec
1285
+ else:
1286
+ if len(self.action_keys) == 1 and self.action_keys[0] != "action":
1287
+ return action_spec
1288
+ try:
1289
+ return action_spec[self.action_key]
1290
+ except KeyError:
1291
+ # the key may have changed
1292
+ raise KeyError(
1293
+ "The action_key attribute seems to have changed. "
1294
+ "This occurs when a action_spec is updated without "
1295
+ "calling `env.action_spec = new_spec`. "
1296
+ "Make sure you rely on this type of command "
1297
+ "to set the action and other specs."
1298
+ )
1299
+
1300
+ @action_spec.setter
1301
+ @_maybe_unlock
1302
+ def action_spec(self, value: TensorSpec) -> None:
1303
+ device = self.input_spec._device
1304
+ if not hasattr(value, "shape"):
1305
+ raise TypeError(
1306
+ f"action_spec of type {type(value)} do not have a shape attribute."
1307
+ )
1308
+ if value.shape[: len(self.batch_size)] != self.batch_size:
1309
+ raise ValueError(
1310
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). "
1311
+ "Please use `env.action_spec_unbatched = value` to set unbatched versions instead."
1312
+ )
1313
+
1314
+ if not isinstance(value, Composite):
1315
+ value = Composite(
1316
+ action=value.to(device), shape=self.batch_size, device=device
1317
+ )
1318
+
1319
+ self.input_spec["full_action_spec"] = value.to(device)
1320
+
1321
+ @property
1322
+ def full_action_spec(self) -> Composite:
1323
+ """The full action spec.
1324
+
1325
+ ``full_action_spec`` is a :class:`~torchrl.data.Composite`` instance
1326
+ that contains all the action entries.
1327
+
1328
+ Examples:
1329
+ >>> from torchrl.envs import BraxEnv
1330
+ >>> for envname in BraxEnv.available_envs:
1331
+ ... break
1332
+ >>> env = BraxEnv(envname)
1333
+ >>> env.full_action_spec
1334
+ Composite(
1335
+ action: BoundedContinuous(
1336
+ shape=torch.Size([8]),
1337
+ space=ContinuousBox(
1338
+ low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
1339
+ high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)),
1340
+ device=cpu,
1341
+ dtype=torch.float32,
1342
+ domain=continuous), device=cpu, shape=torch.Size([]))
1343
+
1344
+ """
1345
+ full_action_spec = self.input_spec.get("full_action_spec", None)
1346
+ if full_action_spec is None:
1347
+ is_locked = self.is_spec_locked
1348
+ if is_locked:
1349
+ self.set_spec_lock_(False)
1350
+ full_action_spec = Composite(shape=self.batch_size, device=self.device)
1351
+ self.input_spec["full_action_spec"] = full_action_spec
1352
+ if is_locked:
1353
+ self.set_spec_lock_(True)
1354
+ return full_action_spec
1355
+
1356
+ @full_action_spec.setter
1357
+ def full_action_spec(self, spec: Composite) -> None:
1358
+ self.action_spec = spec
1359
+
1360
+ # Reward spec
1361
+ @property
1362
+ @_cache_value
1363
+ def reward_keys(self) -> list[NestedKey]:
1364
+ """The reward keys of an environment.
1365
+
1366
+ By default, there will only be one key named "reward".
1367
+
1368
+ Keys are sorted by depth in the data tree.
1369
+ """
1370
+ reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
1371
+ return reward_keys
1372
+
1373
+ @property
1374
+ @_cache_value
1375
+ def observation_keys(self) -> list[NestedKey]:
1376
+ """The observation keys of an environment.
1377
+
1378
+ By default, there will only be one key named "observation".
1379
+
1380
+ Keys are sorted by depth in the data tree.
1381
+ """
1382
+ observation_keys = sorted(
1383
+ self.full_observation_spec.keys(True, True), key=_repr_by_depth
1384
+ )
1385
+ return observation_keys
1386
+
1387
+ @property
1388
+ @_cache_value
1389
+ def _observation_keys_step_mdp(self) -> list[NestedKey]:
1390
+ """The observation keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
1391
+ observation_keys_leaves = sorted(
1392
+ self.full_observation_spec.keys(True, True, step_mdp_static_only=True),
1393
+ key=_repr_by_depth,
1394
+ )
1395
+ return observation_keys_leaves
1396
+
1397
+ @property
1398
+ @_cache_value
1399
+ def _state_keys_step_mdp(self) -> list[NestedKey]:
1400
+ """The state keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
1401
+ state_keys_leaves = sorted(
1402
+ self.full_state_spec.keys(True, True, step_mdp_static_only=True),
1403
+ key=_repr_by_depth,
1404
+ )
1405
+ return state_keys_leaves
1406
+
1407
+ @property
1408
+ @_cache_value
1409
+ def _action_keys_step_mdp(self) -> list[NestedKey]:
1410
+ """The action keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
1411
+ action_keys_leaves = sorted(
1412
+ self.full_action_spec.keys(True, True, step_mdp_static_only=True),
1413
+ key=_repr_by_depth,
1414
+ )
1415
+ return action_keys_leaves
1416
+
1417
+ @property
1418
+ @_cache_value
1419
+ def _done_keys_step_mdp(self) -> list[NestedKey]:
1420
+ """The done keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
1421
+ done_keys_leaves = sorted(
1422
+ self.full_done_spec.keys(True, True, step_mdp_static_only=True),
1423
+ key=_repr_by_depth,
1424
+ )
1425
+ return done_keys_leaves
1426
+
1427
+ @property
1428
+ @_cache_value
1429
+ def _reward_keys_step_mdp(self) -> list[NestedKey]:
1430
+ """The reward keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
1431
+ reward_keys_leaves = sorted(
1432
+ self.full_reward_spec.keys(True, True, step_mdp_static_only=True),
1433
+ key=_repr_by_depth,
1434
+ )
1435
+ return reward_keys_leaves
1436
+
1437
+ @property
1438
+ def reward_key(self):
1439
+ """The reward key of an environment.
1440
+
1441
+ By default, this will be "reward".
1442
+
1443
+ If there is more than one reward key in the environment, this function will raise an exception.
1444
+ """
1445
+ if len(self.reward_keys) > 1:
1446
+ raise KeyError(
1447
+ "reward_key requested but more than one key present in the environment"
1448
+ )
1449
+ return self.reward_keys[0]
1450
+
1451
+ # Reward spec: reward specs belong to output_spec
1452
+ @property
1453
+ def reward_spec(self) -> TensorSpec:
1454
+ """The ``reward`` spec.
1455
+
1456
+ The ``reward_spec`` is always stored as a composite spec.
1457
+
1458
+ If the reward spec is provided as a simple spec, this will be returned.
1459
+
1460
+ >>> env.reward_spec = Unbounded(1)
1461
+ >>> env.reward_spec
1462
+ UnboundedContinuous(
1463
+ shape=torch.Size([1]),
1464
+ space=ContinuousBox(
1465
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1466
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1467
+ device=cpu,
1468
+ dtype=torch.float32,
1469
+ domain=continuous)
1470
+
1471
+ If the reward spec is provided as a composite spec and contains only one leaf,
1472
+ this function will return just the leaf.
1473
+
1474
+ >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}})
1475
+ >>> env.reward_spec
1476
+ UnboundedContinuous(
1477
+ shape=torch.Size([1]),
1478
+ space=ContinuousBox(
1479
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1480
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1481
+ device=cpu,
1482
+ dtype=torch.float32,
1483
+ domain=continuous)
1484
+
1485
+ If the reward spec is provided as a composite spec and has more than one leaf,
1486
+ this function will return the whole spec.
1487
+
1488
+ >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}})
1489
+ >>> env.reward_spec
1490
+ Composite(
1491
+ nested: Composite(
1492
+ reward: UnboundedContinuous(
1493
+ shape=torch.Size([1]),
1494
+ space=ContinuousBox(
1495
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1496
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1497
+ device=cpu,
1498
+ dtype=torch.float32,
1499
+ domain=continuous),
1500
+ another_reward: Categorical(
1501
+ shape=torch.Size([]),
1502
+ space=DiscreteBox(n=1),
1503
+ device=cpu,
1504
+ dtype=torch.int64,
1505
+ domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
1506
+
1507
+ To retrieve the full spec passed, use:
1508
+
1509
+ >>> env.output_spec["full_reward_spec"]
1510
+
1511
+ This property is mutable.
1512
+
1513
+ Examples:
1514
+ >>> from torchrl.envs.libs.gym import GymEnv
1515
+ >>> env = GymEnv("Pendulum-v1")
1516
+ >>> env.reward_spec
1517
+ UnboundedContinuous(
1518
+ shape=torch.Size([1]),
1519
+ space=None,
1520
+ device=cpu,
1521
+ dtype=torch.float32,
1522
+ domain=continuous)
1523
+ """
1524
+ try:
1525
+ reward_spec = self.output_spec["full_reward_spec"]
1526
+ except (KeyError, AttributeError):
1527
+ # populate the "reward" entry
1528
+ # this will be raised if there is not full_reward_spec (unlikely) or no reward_key
1529
+ # Since output_spec is lazily populated with an empty composite spec for
1530
+ # reward_spec, the second case is much more likely to occur.
1531
+ self.reward_spec = Unbounded(
1532
+ shape=(*self.batch_size, 1),
1533
+ device=self.device,
1534
+ )
1535
+ reward_spec = self.output_spec["full_reward_spec"]
1536
+
1537
+ reward_keys = self.reward_keys
1538
+ if len(reward_keys) > 1 or not len(reward_keys):
1539
+ return reward_spec
1540
+ else:
1541
+ if len(self.reward_keys) == 1 and self.reward_keys[0] != "reward":
1542
+ return reward_spec
1543
+ return reward_spec[self.reward_keys[0]]
1544
+
1545
+ @reward_spec.setter
1546
+ @_maybe_unlock
1547
+ def reward_spec(self, value: TensorSpec) -> None:
1548
+ device = self.output_spec._device
1549
+ if not hasattr(value, "shape"):
1550
+ raise TypeError(
1551
+ f"reward_spec of type {type(value)} do not have a shape " f"attribute."
1552
+ )
1553
+ if value.shape[: len(self.batch_size)] != self.batch_size:
1554
+ raise ValueError(
1555
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). "
1556
+ "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead."
1557
+ )
1558
+ if not isinstance(value, Composite):
1559
+ value = Composite(
1560
+ reward=value.to(device), shape=self.batch_size, device=device
1561
+ )
1562
+ for leaf in value.values(True, True):
1563
+ if len(leaf.shape) == 0:
1564
+ raise RuntimeError(
1565
+ "the reward_spec's leaves shape cannot be empty (this error"
1566
+ " usually comes from trying to set a reward_spec"
1567
+ " with a null number of dimensions. Try using a multidimensional"
1568
+ " spec instead, for instance with a singleton dimension at the tail)."
1569
+ )
1570
+ self.output_spec["full_reward_spec"] = value.to(device)
1571
+
1572
+ @property
1573
+ def full_reward_spec(self) -> Composite:
1574
+ """The full reward spec.
1575
+
1576
+ ``full_reward_spec`` is a :class:`~torchrl.data.Composite`` instance
1577
+ that contains all the reward entries.
1578
+
1579
+ Examples:
1580
+ >>> import gymnasium
1581
+ >>> from torchrl.envs import GymWrapper, TransformedEnv, RenameTransform
1582
+ >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1"))
1583
+ >>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward")))
1584
+ >>> env.full_reward_spec
1585
+ Composite(
1586
+ nested: Composite(
1587
+ reward: UnboundedContinuous(
1588
+ shape=torch.Size([1]),
1589
+ space=ContinuousBox(
1590
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
1591
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
1592
+ device=cpu,
1593
+ dtype=torch.float32,
1594
+ domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
1595
+
1596
+ """
1597
+ try:
1598
+ return self.output_spec["full_reward_spec"]
1599
+ except KeyError:
1600
+ # populate the "reward" entry
1601
+ # this will be raised if there is not full_reward_spec (unlikely) or no reward_key
1602
+ # Since output_spec is lazily populated with an empty composite spec for
1603
+ # reward_spec, the second case is much more likely to occur.
1604
+ self.reward_spec = Unbounded(
1605
+ shape=(*self.batch_size, 1),
1606
+ device=self.device,
1607
+ )
1608
+ return self.output_spec["full_reward_spec"]
1609
+
1610
+ @full_reward_spec.setter
1611
+ @_maybe_unlock
1612
+ def full_reward_spec(self, spec: Composite) -> None:
1613
+ self.reward_spec = spec.to(self.device) if self.device is not None else spec
1614
+
1615
+ # done spec
1616
+ @property
1617
+ @_cache_value
1618
+ def done_keys(self) -> list[NestedKey]:
1619
+ """The done keys of an environment.
1620
+
1621
+ By default, there will only be one key named "done".
1622
+
1623
+ Keys are sorted by depth in the data tree.
1624
+ """
1625
+ done_keys = sorted(self.full_done_spec.keys(True, True), key=_repr_by_depth)
1626
+ return done_keys
1627
+
1628
+ @property
1629
+ def done_key(self):
1630
+ """The done key of an environment.
1631
+
1632
+ By default, this will be "done".
1633
+
1634
+ If there is more than one done key in the environment, this function will raise an exception.
1635
+ """
1636
+ done_keys = self.done_keys
1637
+ if len(done_keys) > 1:
1638
+ raise KeyError(
1639
+ "done_key requested but more than one key present in the environment"
1640
+ )
1641
+ return done_keys[0]
1642
+
1643
+ @property
1644
+ def full_done_spec(self) -> Composite:
1645
+ """The full done spec.
1646
+
1647
+ ``full_done_spec`` is a :class:`~torchrl.data.Composite`` instance
1648
+ that contains all the done entries.
1649
+ It can be used to generate fake data with a structure that mimics the
1650
+ one obtained at runtime.
1651
+
1652
+ Examples:
1653
+ >>> import gymnasium
1654
+ >>> from torchrl.envs import GymWrapper
1655
+ >>> env = GymWrapper(gymnasium.make("Pendulum-v1"))
1656
+ >>> env.full_done_spec
1657
+ Composite(
1658
+ done: Categorical(
1659
+ shape=torch.Size([1]),
1660
+ space=DiscreteBox(n=2),
1661
+ device=cpu,
1662
+ dtype=torch.bool,
1663
+ domain=discrete),
1664
+ truncated: Categorical(
1665
+ shape=torch.Size([1]),
1666
+ space=DiscreteBox(n=2),
1667
+ device=cpu,
1668
+ dtype=torch.bool,
1669
+ domain=discrete), device=cpu, shape=torch.Size([]))
1670
+
1671
+ """
1672
+ return self.output_spec["full_done_spec"]
1673
+
1674
+ @full_done_spec.setter
1675
+ @_maybe_unlock
1676
+ def full_done_spec(self, spec: Composite) -> None:
1677
+ self.done_spec = spec.to(self.device) if self.device is not None else spec
1678
+
1679
+ # Done spec: done specs belong to output_spec
1680
+ @property
1681
+ def done_spec(self) -> TensorSpec:
1682
+ """The ``done`` spec.
1683
+
1684
+ The ``done_spec`` is always stored as a composite spec.
1685
+
1686
+ If the done spec is provided as a simple spec, this will be returned.
1687
+
1688
+ >>> env.done_spec = Categorical(2, dtype=torch.bool)
1689
+ >>> env.done_spec
1690
+ Categorical(
1691
+ shape=torch.Size([]),
1692
+ space=DiscreteBox(n=2),
1693
+ device=cpu,
1694
+ dtype=torch.bool,
1695
+ domain=discrete)
1696
+
1697
+ If the done spec is provided as a composite spec and contains only one leaf,
1698
+ this function will return just the leaf.
1699
+
1700
+ >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool)}})
1701
+ >>> env.done_spec
1702
+ Categorical(
1703
+ shape=torch.Size([]),
1704
+ space=DiscreteBox(n=2),
1705
+ device=cpu,
1706
+ dtype=torch.bool,
1707
+ domain=discrete)
1708
+
1709
+ If the done spec is provided as a composite spec and has more than one leaf,
1710
+ this function will return the whole spec.
1711
+
1712
+ >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool), "another_done": Categorical(2, dtype=torch.bool)}})
1713
+ >>> env.done_spec
1714
+ Composite(
1715
+ nested: Composite(
1716
+ done: Categorical(
1717
+ shape=torch.Size([]),
1718
+ space=DiscreteBox(n=2),
1719
+ device=cpu,
1720
+ dtype=torch.bool,
1721
+ domain=discrete),
1722
+ another_done: Categorical(
1723
+ shape=torch.Size([]),
1724
+ space=DiscreteBox(n=2),
1725
+ device=cpu,
1726
+ dtype=torch.bool,
1727
+ domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
1728
+
1729
+ To always retrieve the full spec passed, use:
1730
+
1731
+ >>> env.output_spec["full_done_spec"]
1732
+
1733
+ This property is mutable.
1734
+
1735
+ Examples:
1736
+ >>> from torchrl.envs.libs.gym import GymEnv
1737
+ >>> env = GymEnv("Pendulum-v1")
1738
+ >>> env.done_spec
1739
+ Categorical(
1740
+ shape=torch.Size([1]),
1741
+ space=DiscreteBox(n=2),
1742
+ device=cpu,
1743
+ dtype=torch.bool,
1744
+ domain=discrete)
1745
+ """
1746
+ done_spec = self.output_spec["full_done_spec"]
1747
+ return done_spec
1748
+
1749
    @_maybe_unlock
    def _create_done_specs(self):
        """Reads through the done specs and makes it so that it's complete.

        If the done_specs contain only a ``"done"`` entry, a similar ``"terminated"`` entry is created.
        Same goes if only ``"terminated"`` key is present.

        If none of ``"done"`` and ``"terminated"`` can be found and the spec is not
        empty, nothing is changed.

        """
        try:
            full_done_spec = self.output_spec["full_done_spec"]
        except KeyError:
            # No done spec yet: build a default composite with boolean
            # singleton "done" and "terminated" entries and stop there.
            full_done_spec = Composite(
                shape=self.output_spec.shape, device=self.output_spec.device
            )
            full_done_spec["done"] = Categorical(
                n=2,
                shape=(*full_done_spec.shape, 1),
                dtype=torch.bool,
                device=self.device,
            )
            full_done_spec["terminated"] = Categorical(
                n=2,
                shape=(*full_done_spec.shape, 1),
                dtype=torch.bool,
                device=self.device,
            )
            self.output_spec["full_done_spec"] = full_done_spec
            return

        def check_local_done(spec):
            # Recursively mirror "done"/"terminated" into each other and
            # verify that sibling leaves all share the same shape.
            shape = None
            for key, item in list(
                spec.items()
            ):  # list to avoid error due to in-loop changes
                # in the case where the spec is non-empty and there is no done and no terminated, we do nothing
                if key == "done" and "terminated" not in spec.keys():
                    spec["terminated"] = item.clone()
                elif key == "terminated" and "done" not in spec.keys():
                    spec["done"] = item.clone()
                elif isinstance(item, Composite):
                    check_local_done(item)
                else:
                    if shape is None:
                        shape = item.shape
                        continue
                    # checks that all shape match
                    if shape != item.shape:
                        raise ValueError(
                            f"All shapes should match in done_spec {spec} (shape={shape}, key={key})."
                        )

            # if the spec is empty, we need to add a done and terminated manually
            if spec.is_empty():
                spec["done"] = Categorical(
                    n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
                )
                spec["terminated"] = Categorical(
                    n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
                )

        # NOTE(review): this explicit lock toggling looks redundant with the
        # @_maybe_unlock decorator above — confirm before simplifying.
        if_locked = self.is_spec_locked
        if if_locked:
            self.is_spec_locked = False
        check_local_done(full_done_spec)
        self.output_spec["full_done_spec"] = full_done_spec
        if if_locked:
            self.is_spec_locked = True
        return
1820
+
1821
+ @done_spec.setter
1822
+ @_maybe_unlock
1823
+ def done_spec(self, value: TensorSpec) -> None:
1824
+ device = self.output_spec.device
1825
+ if not hasattr(value, "shape"):
1826
+ raise TypeError(
1827
+ f"done_spec of type {type(value)} do not have a shape " f"attribute."
1828
+ )
1829
+ if value.shape[: len(self.batch_size)] != self.batch_size:
1830
+ raise ValueError(
1831
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
1832
+ )
1833
+ if not isinstance(value, Composite):
1834
+ value = Composite(
1835
+ done=value.to(device),
1836
+ terminated=value.to(device),
1837
+ shape=self.batch_size,
1838
+ device=device,
1839
+ )
1840
+ for leaf in value.values(True, True):
1841
+ if len(leaf.shape) == 0:
1842
+ raise RuntimeError(
1843
+ "the done_spec's leaves shape cannot be empty (this error"
1844
+ " usually comes from trying to set a reward_spec"
1845
+ " with a null number of dimensions. Try using a multidimensional"
1846
+ " spec instead, for instance with a singleton dimension at the tail)."
1847
+ )
1848
+ self.output_spec["full_done_spec"] = value.to(device)
1849
+ self._create_done_specs()
1850
+
1851
+ # observation spec: observation specs belong to output_spec
1852
+ @property
1853
+ def observation_spec(self) -> Composite:
1854
+ """Observation spec.
1855
+
1856
+ Must be a :class:`torchrl.data.Composite` instance.
1857
+ The keys listed in the spec are directly accessible after reset and step.
1858
+
1859
+ In TorchRL, even though they are not properly speaking "observations"
1860
+ all info, states, results of transforms etc. outputs from the environment are stored in the
1861
+ ``observation_spec``.
1862
+
1863
+ Therefore, ``"observation_spec"`` should be thought as
1864
+ a generic data container for environment outputs that are not done or reward data.
1865
+
1866
+ Examples:
1867
+ >>> from torchrl.envs.libs.gym import GymEnv
1868
+ >>> env = GymEnv("Pendulum-v1")
1869
+ >>> env.observation_spec
1870
+ Composite(
1871
+ observation: BoundedContinuous(
1872
+ shape=torch.Size([3]),
1873
+ space=ContinuousBox(
1874
+ low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
1875
+ high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
1876
+ device=cpu,
1877
+ dtype=torch.float32,
1878
+ domain=continuous), device=cpu, shape=torch.Size([]))
1879
+
1880
+ """
1881
+ observation_spec = self.output_spec.get("full_observation_spec", default=None)
1882
+ if observation_spec is None:
1883
+ is_locked = self.is_spec_locked
1884
+ if is_locked:
1885
+ self.set_spec_lock_(False)
1886
+ observation_spec = Composite(shape=self.batch_size, device=self.device)
1887
+ self.output_spec["full_observation_spec"] = observation_spec
1888
+ if is_locked:
1889
+ self.set_spec_lock_(True)
1890
+
1891
+ return observation_spec
1892
+
1893
+ @observation_spec.setter
1894
+ @_maybe_unlock
1895
+ def observation_spec(self, value: TensorSpec) -> None:
1896
+ if not isinstance(value, Composite):
1897
+ value = Composite(
1898
+ observation=value,
1899
+ device=self.device,
1900
+ batch_size=self.output_spec.batch_size,
1901
+ )
1902
+ elif value.shape[: len(self.batch_size)] != self.batch_size:
1903
+ raise ValueError(
1904
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
1905
+ )
1906
+ if value.shape[: len(self.batch_size)] != self.batch_size:
1907
+ raise ValueError(
1908
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
1909
+ )
1910
+ device = self.output_spec._device
1911
+ self.output_spec["full_observation_spec"] = (
1912
+ value.to(device) if device is not None else value
1913
+ )
1914
+
1915
+ @property
1916
+ def full_observation_spec(self) -> Composite:
1917
+ return self.observation_spec
1918
+
1919
+ @full_observation_spec.setter
1920
+ @_maybe_unlock
1921
+ def full_observation_spec(self, spec: Composite):
1922
+ self.observation_spec = spec
1923
+
1924
+ # state spec: state specs belong to input_spec
1925
+ @property
1926
+ def state_spec(self) -> Composite:
1927
+ """State spec.
1928
+
1929
+ Must be a :class:`torchrl.data.Composite` instance.
1930
+ The keys listed here should be provided as input alongside actions to the environment.
1931
+
1932
+ In TorchRL, even though they are not properly speaking "state"
1933
+ all inputs to the environment that are not actions are stored in the
1934
+ ``state_spec``.
1935
+
1936
+ Therefore, ``"state_spec"`` should be thought as
1937
+ a generic data container for environment inputs that are not action data.
1938
+
1939
+ Examples:
1940
+ >>> from torchrl.envs import BraxEnv
1941
+ >>> for envname in BraxEnv.available_envs:
1942
+ ... break
1943
+ >>> env = BraxEnv(envname)
1944
+ >>> env.state_spec
1945
+ Composite(
1946
+ state: Composite(
1947
+ pipeline_state: Composite(
1948
+ q: UnboundedContinuous(
1949
+ shape=torch.Size([15]),
1950
+ space=None,
1951
+ device=cpu,
1952
+ dtype=torch.float32,
1953
+ domain=continuous),
1954
+ [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
1955
+
1956
+
1957
+ """
1958
+ state_spec = self.input_spec["full_state_spec"]
1959
+ if state_spec is None:
1960
+ is_locked = self.is_spec_locked
1961
+ if is_locked:
1962
+ self.set_spec_lock_(False)
1963
+ state_spec = Composite(shape=self.batch_size, device=self.device)
1964
+ self.input_spec["full_state_spec"] = state_spec
1965
+ if is_locked:
1966
+ self.set_spec_lock_(True)
1967
+ return state_spec
1968
+
1969
+ @state_spec.setter
1970
+ @_maybe_unlock
1971
+ def state_spec(self, value: Composite) -> None:
1972
+ if value is None:
1973
+ self.input_spec["full_state_spec"] = Composite(
1974
+ device=self.device, shape=self.batch_size
1975
+ )
1976
+ else:
1977
+ device = self.input_spec.device
1978
+ if not isinstance(value, Composite):
1979
+ raise TypeError("The type of an state_spec must be Composite.")
1980
+ elif value.shape[: len(self.batch_size)] != self.batch_size:
1981
+ raise ValueError(
1982
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
1983
+ )
1984
+ if value.shape[: len(self.batch_size)] != self.batch_size:
1985
+ raise ValueError(
1986
+ f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
1987
+ )
1988
+ self.input_spec["full_state_spec"] = (
1989
+ value.to(device) if device is not None else value
1990
+ )
1991
+
1992
+ @property
1993
+ def full_state_spec(self) -> Composite:
1994
+ """The full state spec.
1995
+
1996
+ ``full_state_spec`` is a :class:`~torchrl.data.Composite`` instance
1997
+ that contains all the state entries (ie, the input data that is not action).
1998
+
1999
+ Examples:
2000
+ >>> from torchrl.envs import BraxEnv
2001
+ >>> for envname in BraxEnv.available_envs:
2002
+ ... break
2003
+ >>> env = BraxEnv(envname)
2004
+ >>> env.full_state_spec
2005
+ Composite(
2006
+ state: Composite(
2007
+ pipeline_state: Composite(
2008
+ q: UnboundedContinuous(
2009
+ shape=torch.Size([15]),
2010
+ space=None,
2011
+ device=cpu,
2012
+ dtype=torch.float32,
2013
+ domain=continuous),
2014
+ [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
2015
+
2016
+ """
2017
+ return self.state_spec
2018
+
2019
+ @full_state_spec.setter
2020
+ @_maybe_unlock
2021
+ def full_state_spec(self, spec: Composite) -> None:
2022
+ self.state_spec = spec
2023
+
2024
    # Single-env specs can be used to remove the batch size from the spec
    @property
    def batch_dims(self) -> int:
        """Number of batch dimensions of the env (``len(env.batch_size)``)."""
        return len(self.batch_size)
2029
+
2030
+ def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
2031
+ if not self.batch_dims:
2032
+ return spec
2033
+ idx = tuple(0 for _ in range(self.batch_dims))
2034
+ return spec[idx]
2035
+
2036
+ @property
2037
+ def full_action_spec_unbatched(self) -> Composite:
2038
+ """Returns the action spec of the env as if it had no batch dimensions."""
2039
+ return self._make_single_env_spec(self.full_action_spec)
2040
+
2041
+ @full_action_spec_unbatched.setter
2042
+ @_maybe_unlock
2043
+ def full_action_spec_unbatched(self, spec: Composite):
2044
+ spec = spec.expand(self.batch_size + spec.shape)
2045
+ self.full_action_spec = spec
2046
+
2047
+ @property
2048
+ def action_spec_unbatched(self) -> TensorSpec:
2049
+ """Returns the action spec of the env as if it had no batch dimensions."""
2050
+ return self._make_single_env_spec(self.action_spec)
2051
+
2052
+ @action_spec_unbatched.setter
2053
+ @_maybe_unlock
2054
+ def action_spec_unbatched(self, spec: Composite):
2055
+ spec = spec.expand(self.batch_size + spec.shape)
2056
+ self.action_spec = spec
2057
+
2058
+ @property
2059
+ def full_observation_spec_unbatched(self) -> Composite:
2060
+ """Returns the observation spec of the env as if it had no batch dimensions."""
2061
+ return self._make_single_env_spec(self.full_observation_spec)
2062
+
2063
+ @full_observation_spec_unbatched.setter
2064
+ @_maybe_unlock
2065
+ def full_observation_spec_unbatched(self, spec: Composite):
2066
+ spec = spec.expand(self.batch_size + spec.shape)
2067
+ self.full_observation_spec = spec
2068
+
2069
+ @property
2070
+ def observation_spec_unbatched(self) -> Composite:
2071
+ """Returns the observation spec of the env as if it had no batch dimensions."""
2072
+ return self._make_single_env_spec(self.observation_spec)
2073
+
2074
+ @observation_spec_unbatched.setter
2075
+ @_maybe_unlock
2076
+ def observation_spec_unbatched(self, spec: Composite):
2077
+ spec = spec.expand(self.batch_size + spec.shape)
2078
+ self.observation_spec = spec
2079
+
2080
+ @property
2081
+ def full_reward_spec_unbatched(self) -> Composite:
2082
+ """Returns the reward spec of the env as if it had no batch dimensions."""
2083
+ return self._make_single_env_spec(self.full_reward_spec)
2084
+
2085
+ @full_reward_spec_unbatched.setter
2086
+ @_maybe_unlock
2087
+ def full_reward_spec_unbatched(self, spec: Composite):
2088
+ spec = spec.expand(self.batch_size + spec.shape)
2089
+ self.full_reward_spec = spec
2090
+
2091
+ @property
2092
+ def reward_spec_unbatched(self) -> TensorSpec:
2093
+ """Returns the reward spec of the env as if it had no batch dimensions."""
2094
+ return self._make_single_env_spec(self.reward_spec)
2095
+
2096
+ @reward_spec_unbatched.setter
2097
+ @_maybe_unlock
2098
+ def reward_spec_unbatched(self, spec: Composite):
2099
+ spec = spec.expand(self.batch_size + spec.shape)
2100
+ self.reward_spec = spec
2101
+
2102
+ @property
2103
+ def full_done_spec_unbatched(self) -> Composite:
2104
+ """Returns the done spec of the env as if it had no batch dimensions."""
2105
+ return self._make_single_env_spec(self.full_done_spec)
2106
+
2107
+ @full_done_spec_unbatched.setter
2108
+ @_maybe_unlock
2109
+ def full_done_spec_unbatched(self, spec: Composite):
2110
+ spec = spec.expand(self.batch_size + spec.shape)
2111
+ self.full_done_spec = spec
2112
+
2113
+ @property
2114
+ def done_spec_unbatched(self) -> TensorSpec:
2115
+ """Returns the done spec of the env as if it had no batch dimensions."""
2116
+ return self._make_single_env_spec(self.done_spec)
2117
+
2118
+ @done_spec_unbatched.setter
2119
+ @_maybe_unlock
2120
+ def done_spec_unbatched(self, spec: Composite):
2121
+ spec = spec.expand(self.batch_size + spec.shape)
2122
+ self.done_spec = spec
2123
+
2124
+ @property
2125
+ def output_spec_unbatched(self) -> Composite:
2126
+ """Returns the output spec of the env as if it had no batch dimensions."""
2127
+ return self._make_single_env_spec(self.output_spec)
2128
+
2129
+ @output_spec_unbatched.setter
2130
+ @_maybe_unlock
2131
+ def output_spec_unbatched(self, spec: Composite):
2132
+ spec = spec.expand(self.batch_size + spec.shape)
2133
+ self.output_spec = spec
2134
+
2135
+ @property
2136
+ def input_spec_unbatched(self) -> Composite:
2137
+ """Returns the input spec of the env as if it had no batch dimensions."""
2138
+ return self._make_single_env_spec(self.input_spec)
2139
+
2140
+ @input_spec_unbatched.setter
2141
+ @_maybe_unlock
2142
+ def input_spec_unbatched(self, spec: Composite):
2143
+ spec = spec.expand(self.batch_size + spec.shape)
2144
+ self.input_spec = spec
2145
+
2146
+ @property
2147
+ def full_state_spec_unbatched(self) -> Composite:
2148
+ """Returns the state spec of the env as if it had no batch dimensions."""
2149
+ return self._make_single_env_spec(self.full_state_spec)
2150
+
2151
+ @full_state_spec_unbatched.setter
2152
+ @_maybe_unlock
2153
+ def full_state_spec_unbatched(self, spec: Composite):
2154
+ spec = spec.expand(self.batch_size + spec.shape)
2155
+ self.full_state_spec = spec
2156
+
2157
+ @property
2158
+ def state_spec_unbatched(self) -> TensorSpec:
2159
+ """Returns the state spec of the env as if it had no batch dimensions."""
2160
+ return self._make_single_env_spec(self.state_spec)
2161
+
2162
+ @state_spec_unbatched.setter
2163
+ @_maybe_unlock
2164
+ def state_spec_unbatched(self, spec: Composite):
2165
+ spec = spec.expand(self.batch_size + spec.shape)
2166
+ self.state_spec = spec
2167
+
2168
    def _skip_tensordict(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Build a placeholder ``next`` tensordict for a skipped step.

        Zeros are produced for every done, observation and reward entry, then
        any matching entries already present in ``tensordict`` are copied over
        (cloned, and moved to the spec's device when they differ).
        """
        # Creates a "skip" tensordict, ie a placeholder for when a step is skipped
        next_tensordict = self.full_done_spec.zero()
        next_tensordict.update(self.full_observation_spec.zero())
        next_tensordict.update(self.full_reward_spec.zero())

        # Copy the data from tensordict in `next`
        keys = set()

        def select_and_clone(name, x, y):
            # Track every key visited so the leftover zeroed entries can be
            # identified afterwards; returns None (drop) when y is None.
            keys.add(name)
            if y is not None:
                if y.device == x.device:
                    return x.clone()
                return x.to(y.device)

        result = tensordict._fast_apply(
            select_and_clone,
            next_tensordict,
            device=self.device,
            default=None,
            filter_empty=True,
            is_leaf=_is_leaf_nontensor,
            named=True,
            nested_keys=True,
        )
        # Fill in the zeroed entries that were not present in tensordict.
        result.update(next_tensordict.exclude(*keys).filter_empty_())
        return result
2196
+
2197
+ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
2198
+ """Makes a step in the environment.
2199
+
2200
+ Step accepts a single argument, tensordict, which usually carries an 'action' key which indicates the action
2201
+ to be taken.
2202
+ Step will call an out-place private method, _step, which is the method to be re-written by EnvBase subclasses.
2203
+
2204
+ Args:
2205
+ tensordict (TensorDictBase): Tensordict containing the action to be taken.
2206
+ If the input tensordict contains a ``"next"`` entry, the values contained in it
2207
+ will prevail over the newly computed values. This gives a mechanism
2208
+ to override the underlying computations.
2209
+
2210
+ Returns:
2211
+ the input tensordict, modified in place with the resulting observations, done state and reward
2212
+ (+ others if needed).
2213
+
2214
+ """
2215
+ # sanity check
2216
+ self._assert_tensordict_shape(tensordict)
2217
+ partial_steps = tensordict.pop("_step", None)
2218
+
2219
+ next_tensordict = None
2220
+
2221
+ if partial_steps is not None:
2222
+ tensordict_batch_size = None
2223
+ if not self.batch_locked:
2224
+ # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
2225
+ if partial_steps.all():
2226
+ partial_steps = None
2227
+ else:
2228
+ tensordict_batch_size = tensordict.batch_size
2229
+ partial_steps = partial_steps.view(tensordict_batch_size)
2230
+ tensordict = tensordict[partial_steps]
2231
+ else:
2232
+ if not partial_steps.any():
2233
+ next_tensordict = self._skip_tensordic(tensordict)
2234
+ else:
2235
+ # trust that the _step can handle this!
2236
+ tensordict.set("_step", partial_steps)
2237
+ if tensordict_batch_size is None:
2238
+ tensordict_batch_size = self.batch_size
2239
+
2240
+ next_preset = tensordict.get("next", None)
2241
+
2242
+ if next_tensordict is None:
2243
+ next_tensordict = self._step(tensordict)
2244
+ next_tensordict = self._step_proc_data(next_tensordict)
2245
+ if next_preset is not None:
2246
+ # tensordict could already have a "next" key
2247
+ # this could be done more efficiently by not excluding but just passing
2248
+ # the necessary keys
2249
+ next_tensordict.update(
2250
+ next_preset.exclude(*next_tensordict.keys(True, True))
2251
+ )
2252
+ tensordict.set("next", next_tensordict)
2253
+ if partial_steps is not None and tensordict_batch_size != self.batch_size:
2254
+ result = tensordict.new_zeros(tensordict_batch_size)
2255
+
2256
+ if tensordict_batch_size == tensordict.batch_size:
2257
+
2258
+ def select_and_clone(x, y):
2259
+ if y is not None:
2260
+ if x.device == y.device:
2261
+ return x.clone()
2262
+ return x.to(y.device)
2263
+
2264
+ result.update(
2265
+ tensordict._fast_apply(
2266
+ select_and_clone,
2267
+ result,
2268
+ device=result.device,
2269
+ filter_empty=True,
2270
+ default=None,
2271
+ batch_size=result.batch_size,
2272
+ is_leaf=_is_leaf_nontensor,
2273
+ )
2274
+ )
2275
+ if partial_steps.any():
2276
+ result[partial_steps] = tensordict
2277
+ return result
2278
+ return tensordict
2279
+
2280
    @classmethod
    def _complete_done(
        cls, done_spec: Composite, data: TensorDictBase
    ) -> TensorDictBase:
        """Completes the data structure at step time to put missing done keys."""
        # by default, if a done key is missing, it is assumed that it is False
        # except in 2 cases: (1) there is a "done" but no "terminated" or (2)
        # there is a "terminated" but no "done".
        if done_spec.ndim:
            # Dims of `data` that lead the done-spec dims.
            leading_dim = data.shape[: -done_spec.ndim]
        else:
            leading_dim = data.shape
        vals = {}
        # i tracks how many entries the spec declares; -1 handles an empty spec.
        i = -1
        for i, (key, item) in enumerate(done_spec.items()):  # noqa: B007
            val = data.get(key, None)
            if isinstance(item, Composite):
                # Recurse into nested done specs (e.g. per-agent done keys).
                if val is not None:
                    cls._complete_done(item, val)
                continue
            shape = (*leading_dim, *item.shape)
            if val is not None:
                if val.shape != shape:
                    # Conform the existing value to the expected shape.
                    val = val.reshape(shape)
                    data.set(key, val)
                vals[key] = val

        if len(vals) < i + 1:
            # complete missing dones: we only want to do that if we don't have enough done values
            data_keys = set(data.keys())
            done_spec_keys = set(done_spec.keys())
            for key, item in done_spec.items(False, True):
                val = vals.get(key, None)
                if (
                    key == "done"
                    and val is not None
                    and "terminated" in done_spec_keys
                    and "terminated" not in data_keys
                ):
                    # Case (1): copy "done" into the missing "terminated"
                    # (only valid when "truncated" is absent, otherwise ambiguous).
                    if "truncated" in data_keys:
                        raise RuntimeError(
                            "Cannot infer the value of terminated when only done and truncated are present."
                        )
                    data.set("terminated", val)
                    data_keys.add("terminated")
                elif (
                    key == "terminated"
                    and val is not None
                    and "done" in done_spec_keys
                    and "done" not in data_keys
                ):
                    # Case (2): derive "done" from "terminated" (OR "truncated" if present).
                    if "truncated" in data_keys:
                        val = val | data.get("truncated")
                    data.set("done", val)
                    data_keys.add("done")
                elif val is None and key not in data_keys:
                    # we must keep this here: we only want to fill with 0s if we're sure
                    # done should not be copied to terminated or terminated to done
                    # in this case, just fill with 0s
                    data.set(key, item.zero(leading_dim))
        return data
2341
+
2342
+ def _step_proc_data(self, next_tensordict_out):
2343
+ batch_size = self.batch_size
2344
+ dims = len(batch_size)
2345
+ leading_batch_size = (
2346
+ next_tensordict_out.batch_size[:-dims]
2347
+ if dims
2348
+ else next_tensordict_out.shape
2349
+ )
2350
+ for reward_key in self.reward_keys:
2351
+ expected_reward_shape = torch.Size(
2352
+ [
2353
+ *leading_batch_size,
2354
+ *self.output_spec["full_reward_spec"][reward_key].shape,
2355
+ ]
2356
+ )
2357
+ # If the reward has a variable shape, we don't want to perform this check
2358
+ if all(s > 0 for s in expected_reward_shape):
2359
+ reward = next_tensordict_out.get(reward_key)
2360
+ actual_reward_shape = reward.shape
2361
+ if actual_reward_shape != expected_reward_shape:
2362
+ reward = reward.view(expected_reward_shape)
2363
+ next_tensordict_out.set(reward_key, reward)
2364
+
2365
+ self._complete_done(self.full_done_spec, next_tensordict_out)
2366
+
2367
+ if self.run_type_checks:
2368
+ for key, spec in self.observation_spec.items():
2369
+ obs = next_tensordict_out.get(key)
2370
+ spec.type_check(obs)
2371
+
2372
+ for reward_key in self.reward_keys:
2373
+ if (
2374
+ next_tensordict_out.get(reward_key).dtype
2375
+ is not self.output_spec[
2376
+ unravel_key(("full_reward_spec", reward_key))
2377
+ ].dtype
2378
+ ):
2379
+ raise TypeError(
2380
+ f"expected reward.dtype to be {self.output_spec[unravel_key(('full_reward_spec',reward_key))]} "
2381
+ f"but got {next_tensordict_out.get(reward_key).dtype}"
2382
+ )
2383
+
2384
+ for done_key in self.done_keys:
2385
+ if (
2386
+ next_tensordict_out.get(done_key).dtype
2387
+ is not self.output_spec["full_done_spec", done_key].dtype
2388
+ ):
2389
+ raise TypeError(
2390
+ f"expected done.dtype to be {self.output_spec['full_done_spec', done_key].dtype} but got {next_tensordict_out.get(done_key).dtype}"
2391
+ )
2392
+ return next_tensordict_out
2393
+
2394
+ def _get_in_keys_to_exclude(self, tensordict):
2395
+ if self._cache_in_keys is None:
2396
+ self._cache_in_keys = list(
2397
+ set(self.input_spec.keys(True)).intersection(
2398
+ tensordict.keys(True, True)
2399
+ )
2400
+ )
2401
+ return self._cache_in_keys
2402
+
2403
+ @classmethod
2404
+ def register_gym(
2405
+ cls,
2406
+ id: str,
2407
+ *,
2408
+ entry_point: Callable | None = None,
2409
+ transform: Transform | None = None, # noqa: F821
2410
+ info_keys: list[NestedKey] | None = None,
2411
+ backend: str | None = None,
2412
+ to_numpy: bool = False,
2413
+ reward_threshold: float | None = None,
2414
+ nondeterministic: bool = False,
2415
+ max_episode_steps: int | None = None,
2416
+ order_enforce: bool = True,
2417
+ autoreset: bool | None = None,
2418
+ disable_env_checker: bool = False,
2419
+ apply_api_compatibility: bool = False,
2420
+ **kwargs,
2421
+ ):
2422
+ """Registers an environment in gym(nasium).
2423
+
2424
+ This method is designed with the following scopes in mind:
2425
+
2426
+ - Incorporate a TorchRL-first environment in a framework that uses Gym;
2427
+ - Incorporate another environment (eg, DeepMind Control, Brax, Jumanji, ...)
2428
+ in a framework that uses Gym.
2429
+
2430
+ Args:
2431
+ id (str): the name of the environment. Should follow the
2432
+ `gym naming convention <https://www.gymlibrary.dev/content/environment_creation/#registering-envs>`_.
2433
+
2434
+ Keyword Args:
2435
+ entry_point (callable, optional): the entry point to build the environment.
2436
+ If none is passed, the parent class will be used as entry point.
2437
+ Typically, this is used to register an environment that does not
2438
+ necessarily inherit from the base being used:
2439
+
2440
+ >>> from torchrl.envs import DMControlEnv
2441
+ >>> DMControlEnv.register_gym("DMC-cheetah-v0", env_name="cheetah", task="run")
2442
+ >>> # equivalently
2443
+ >>> EnvBase.register_gym("DMC-cheetah-v0", entry_point=DMControlEnv, env_name="cheetah", task="run")
2444
+
2445
+ transform (torchrl.envs.Transform): a transform (or list of transforms
2446
+ within a :class:`torchrl.envs.Compose` instance) to be used with the env.
2447
+ This arg can be passed during a call to :func:`~gym.make` (see
2448
+ example below).
2449
+ info_keys (List[NestedKey], optional): if provided, these keys will
2450
+ be used to build the info dictionary and will be excluded from
2451
+ the observation keys.
2452
+ This arg can be passed during a call to :func:`~gym.make` (see
2453
+ example below).
2454
+
2455
+ .. warning::
2456
+ It may be the case that using ``info_keys`` makes a spec empty
2457
+ because the content has been moved to the info dictionary.
2458
+ Gym does not like empty ``Dict`` in the specs, so this empty
2459
+ content should be removed with :class:`~torchrl.envs.transforms.RemoveEmptySpecs`.
2460
+
2461
+ backend (str, optional): the backend. Can be either `"gym"` or `"gymnasium"`
2462
+ or any other backend compatible with :class:`~torchrl.envs.libs.gym.set_gym_backend`.
2463
+ to_numpy (bool, optional): if ``True``, the result of calls to `step` and
2464
+ `reset` will be mapped to numpy arrays. Defaults to ``False``
2465
+ (results are tensors).
2466
+ This arg can be passed during a call to :func:`~gym.make` (see
2467
+ example below).
2468
+ reward_threshold (:obj:`float`, optional): [Gym kwarg] The reward threshold
2469
+ considered to have learnt an environment.
2470
+ nondeterministic (bool, optional): [Gym kwarg If the environment is nondeterministic
2471
+ (even with knowledge of the initial seed and all actions). Defaults to
2472
+ ``False``.
2473
+ max_episode_steps (int, optional): [Gym kwarg] The maximum number
2474
+ of episodes steps before truncation. Used by the Time Limit wrapper.
2475
+ order_enforce (bool, optional): [Gym >= 0.14] Whether the order
2476
+ enforcer wrapper should be applied to ensure users run functions
2477
+ in the correct order.
2478
+ Defaults to ``True``.
2479
+ autoreset (bool, optional): [Gym >= 0.14 and <1.0.0] Whether the autoreset wrapper
2480
+ should be added such that reset does not need to be called.
2481
+ Defaults to ``False``.
2482
+ disable_env_checker: [Gym >= 0.14] Whether the environment
2483
+ checker should be disabled for the environment. Defaults to ``False``.
2484
+ apply_api_compatibility: [Gym >= 0.26 and <1.0.0] If to apply the `StepAPICompatibility` wrapper.
2485
+ Defaults to ``False``.
2486
+ **kwargs: arbitrary keyword arguments which are passed to the environment constructor.
2487
+
2488
+ .. note::
2489
+ TorchRL's environment do not have the concept of an ``"info"`` dictionary,
2490
+ as ``TensorDict`` offers all the storage requirements deemed necessary
2491
+ in most training settings. Still, you can use the ``info_keys`` argument to
2492
+ have a fine grained control over what is deemed to be considered
2493
+ as an observation and what should be seen as info.
2494
+
2495
+ Examples:
2496
+ >>> # Register the "cheetah" env from DMControl with the "run" task
2497
+ >>> from torchrl.envs import DMControlEnv
2498
+ >>> import torch
2499
+ >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gym", env_name="cheetah", task_name="run")
2500
+ >>> import gym
2501
+ >>> envgym = gym.make("DMC-cheetah-v0")
2502
+ >>> envgym.seed(0)
2503
+ >>> torch.manual_seed(0)
2504
+ >>> envgym.reset()
2505
+ ({'position': tensor([-0.0855, 0.0215, -0.0881, -0.0412, -0.1101, 0.0080, 0.0254, 0.0424],
2506
+ dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03, 3.3842e-02, 2.5338e-02,
2507
+ 3.3064e-02, 1.0381e-04, 7.6656e-05, 1.0204e-02],
2508
+ dtype=torch.float64)}, {})
2509
+ >>> envgym.step(envgym.action_space.sample())
2510
+ ({'position': tensor([-0.0833, 0.0275, -0.0612, -0.0770, -0.1256, 0.0082, 0.0186, 0.0476],
2511
+ dtype=torch.float64), 'velocity': tensor([ 0.2221, 0.2256, 0.5930, 2.6937, -3.5865, -1.5479, 0.0187, -0.6825,
2512
+ 0.5224], dtype=torch.float64)}, tensor([0.0018], dtype=torch.float64), tensor([False]), tensor([False]), {})
2513
+ >>> # same environment with observation stacked
2514
+ >>> from torchrl.envs import CatTensors
2515
+ >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"))
2516
+ >>> envgym.reset()
2517
+ ({'observation': tensor([-0.1005, 0.0335, -0.0268, 0.0133, -0.0627, 0.0074, -0.0488, -0.0353,
2518
+ -0.0075, -0.0069, 0.0098, -0.0058, 0.0033, -0.0157, -0.0004, -0.0381,
2519
+ -0.0452], dtype=torch.float64)}, {})
2520
+ >>> # same environment with numpy observations
2521
+ >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"), to_numpy=True)
2522
+ >>> envgym.reset()
2523
+ ({'observation': array([-0.11355747, 0.04257728, 0.00408397, 0.04155852, -0.0389733 ,
2524
+ -0.01409826, -0.0978704 , -0.08808327, 0.03970837, 0.00535434,
2525
+ -0.02353762, 0.05116226, 0.02788907, 0.06848346, 0.05154399,
2526
+ 0.0371798 , 0.05128025])}, {})
2527
+ >>> # If gymnasium is installed, we can register the environment there too.
2528
+ >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gymnasium", env_name="cheetah", task_name="run")
2529
+ >>> import gymnasium
2530
+ >>> envgym = gymnasium.make("DMC-cheetah-v0")
2531
+ >>> envgym.seed(0)
2532
+ >>> torch.manual_seed(0)
2533
+ >>> envgym.reset()
2534
+ ({'position': tensor([-0.0855, 0.0215, -0.0881, -0.0412, -0.1101, 0.0080, 0.0254, 0.0424],
2535
+ dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03, 3.3842e-02, 2.5338e-02,
2536
+ 3.3064e-02, 1.0381e-04, 7.6656e-05, 1.0204e-02],
2537
+ dtype=torch.float64)}, {})
2538
+
2539
+ .. note::
2540
+ This feature also works for stateless environments (eg, :class:`~torchrl.envs.BraxEnv`).
2541
+
2542
+ >>> import gymnasium
2543
+ >>> import torch
2544
+ >>> from tensordict import TensorDict
2545
+ >>> from torchrl.envs import BraxEnv, SelectTransform
2546
+ >>>
2547
+ >>> # get action for dydactic purposes
2548
+ >>> env = BraxEnv("ant", batch_size=[2])
2549
+ >>> env.set_seed(0)
2550
+ >>> torch.manual_seed(0)
2551
+ >>> td = env.rollout(10)
2552
+ >>>
2553
+ >>> actions = td.get("action")
2554
+ >>>
2555
+ >>> # register env
2556
+ >>> env.register_gym("Brax-Ant-v0", env_name="ant", batch_size=[2], info_keys=["state"])
2557
+ >>> gym_env = gymnasium.make("Brax-Ant-v0")
2558
+ >>> gym_env.seed(0)
2559
+ >>> torch.manual_seed(0)
2560
+ >>>
2561
+ >>> gym_env.reset()
2562
+ >>> obs = []
2563
+ >>> for i in range(10):
2564
+ ... obs, reward, terminated, truncated, info = gym_env.step(td[..., i].get("action"))
2565
+
2566
+
2567
+ """
2568
+ from torchrl.envs.libs.gym import gym_backend, set_gym_backend
2569
+
2570
+ if backend is None:
2571
+ backend = gym_backend()
2572
+
2573
+ with set_gym_backend(backend):
2574
+ return cls._register_gym(
2575
+ id=id,
2576
+ entry_point=entry_point,
2577
+ transform=transform,
2578
+ info_keys=info_keys,
2579
+ to_numpy=to_numpy,
2580
+ reward_threshold=reward_threshold,
2581
+ nondeterministic=nondeterministic,
2582
+ max_episode_steps=max_episode_steps,
2583
+ order_enforce=order_enforce,
2584
+ autoreset=autoreset,
2585
+ disable_env_checker=disable_env_checker,
2586
+ apply_api_compatibility=apply_api_compatibility,
2587
+ **kwargs,
2588
+ )
2589
+
2590
    # Error template used by the version-specific ``_register_gym`` overloads
    # when a keyword argument is not supported by the installed gym(nasium) version.
    _GYM_UNRECOGNIZED_KWARG = (
        "The keyword argument {} is not compatible with gym version {}"
    )
2593
+
2594
+ @implement_for("gym", "0.26", None, class_method=True)
2595
+ def _register_gym(
2596
+ cls,
2597
+ id,
2598
+ entry_point: Callable | None = None,
2599
+ transform: Transform | None = None, # noqa: F821
2600
+ info_keys: list[NestedKey] | None = None,
2601
+ to_numpy: bool = False,
2602
+ reward_threshold: float | None = None,
2603
+ nondeterministic: bool = False,
2604
+ max_episode_steps: int | None = None,
2605
+ order_enforce: bool = True,
2606
+ autoreset: bool | None = None,
2607
+ disable_env_checker: bool = False,
2608
+ apply_api_compatibility: bool = False,
2609
+ **kwargs,
2610
+ ):
2611
+ import gym
2612
+ from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper
2613
+
2614
+ if entry_point is None:
2615
+ entry_point = cls
2616
+ entry_point = partial(
2617
+ _TorchRLGymWrapper,
2618
+ entry_point=entry_point,
2619
+ info_keys=info_keys,
2620
+ to_numpy=to_numpy,
2621
+ transform=transform,
2622
+ **kwargs,
2623
+ )
2624
+ return gym.register(
2625
+ id=id,
2626
+ entry_point=entry_point,
2627
+ reward_threshold=reward_threshold,
2628
+ nondeterministic=nondeterministic,
2629
+ max_episode_steps=max_episode_steps,
2630
+ order_enforce=order_enforce,
2631
+ autoreset=bool(autoreset),
2632
+ disable_env_checker=disable_env_checker,
2633
+ apply_api_compatibility=apply_api_compatibility,
2634
+ )
2635
+
2636
+ @implement_for("gym", "0.25", "0.26", class_method=True)
2637
+ def _register_gym( # noqa: F811
2638
+ cls,
2639
+ id,
2640
+ entry_point: Callable | None = None,
2641
+ transform: Transform | None = None, # noqa: F821
2642
+ info_keys: list[NestedKey] | None = None,
2643
+ to_numpy: bool = False,
2644
+ reward_threshold: float | None = None,
2645
+ nondeterministic: bool = False,
2646
+ max_episode_steps: int | None = None,
2647
+ order_enforce: bool = True,
2648
+ autoreset: bool | None = None,
2649
+ disable_env_checker: bool = False,
2650
+ apply_api_compatibility: bool = False,
2651
+ **kwargs,
2652
+ ):
2653
+ import gym
2654
+
2655
+ if apply_api_compatibility is not False:
2656
+ raise TypeError(
2657
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2658
+ "apply_api_compatibility", gym.__version__
2659
+ )
2660
+ )
2661
+ from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper
2662
+
2663
+ if entry_point is None:
2664
+ entry_point = cls
2665
+ entry_point = partial(
2666
+ _TorchRLGymWrapper,
2667
+ entry_point=entry_point,
2668
+ info_keys=info_keys,
2669
+ to_numpy=to_numpy,
2670
+ transform=transform,
2671
+ **kwargs,
2672
+ )
2673
+ return gym.register(
2674
+ id=id,
2675
+ entry_point=entry_point,
2676
+ reward_threshold=reward_threshold,
2677
+ nondeterministic=nondeterministic,
2678
+ max_episode_steps=max_episode_steps,
2679
+ order_enforce=order_enforce,
2680
+ autoreset=bool(autoreset),
2681
+ disable_env_checker=disable_env_checker,
2682
+ )
2683
+
2684
+ @implement_for("gym", "0.24", "0.25", class_method=True)
2685
+ def _register_gym( # noqa: F811
2686
+ cls,
2687
+ id,
2688
+ entry_point: Callable | None = None,
2689
+ transform: Transform | None = None, # noqa: F821
2690
+ info_keys: list[NestedKey] | None = None,
2691
+ to_numpy: bool = False,
2692
+ reward_threshold: float | None = None,
2693
+ nondeterministic: bool = False,
2694
+ max_episode_steps: int | None = None,
2695
+ order_enforce: bool = True,
2696
+ autoreset: bool | None = None,
2697
+ disable_env_checker: bool = False,
2698
+ apply_api_compatibility: bool = False,
2699
+ **kwargs,
2700
+ ):
2701
+ import gym
2702
+
2703
+ if apply_api_compatibility is not False:
2704
+ raise TypeError(
2705
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2706
+ "apply_api_compatibility", gym.__version__
2707
+ )
2708
+ )
2709
+ if disable_env_checker is not False:
2710
+ raise TypeError(
2711
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2712
+ "disable_env_checker", gym.__version__
2713
+ )
2714
+ )
2715
+ from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper
2716
+
2717
+ if entry_point is None:
2718
+ entry_point = cls
2719
+ entry_point = partial(
2720
+ _TorchRLGymWrapper,
2721
+ entry_point=entry_point,
2722
+ info_keys=info_keys,
2723
+ to_numpy=to_numpy,
2724
+ transform=transform,
2725
+ **kwargs,
2726
+ )
2727
+ return gym.register(
2728
+ id=id,
2729
+ entry_point=entry_point,
2730
+ reward_threshold=reward_threshold,
2731
+ nondeterministic=nondeterministic,
2732
+ max_episode_steps=max_episode_steps,
2733
+ order_enforce=order_enforce,
2734
+ autoreset=bool(autoreset),
2735
+ )
2736
+
2737
+ @implement_for("gym", "0.21", "0.24", class_method=True)
2738
+ def _register_gym( # noqa: F811
2739
+ cls,
2740
+ id,
2741
+ entry_point: Callable | None = None,
2742
+ transform: Transform | None = None, # noqa: F821
2743
+ info_keys: list[NestedKey] | None = None,
2744
+ to_numpy: bool = False,
2745
+ reward_threshold: float | None = None,
2746
+ nondeterministic: bool = False,
2747
+ max_episode_steps: int | None = None,
2748
+ order_enforce: bool = True,
2749
+ autoreset: bool | None = None,
2750
+ disable_env_checker: bool = False,
2751
+ apply_api_compatibility: bool = False,
2752
+ **kwargs,
2753
+ ):
2754
+ import gym
2755
+
2756
+ if apply_api_compatibility is not False:
2757
+ raise TypeError(
2758
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2759
+ "apply_api_compatibility", gym.__version__
2760
+ )
2761
+ )
2762
+ if disable_env_checker is not False:
2763
+ raise TypeError(
2764
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2765
+ "disable_env_checker", gym.__version__
2766
+ )
2767
+ )
2768
+ if autoreset is not None:
2769
+ raise TypeError(
2770
+ cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
2771
+ )
2772
+ from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper
2773
+
2774
+ if entry_point is None:
2775
+ entry_point = cls
2776
+ entry_point = partial(
2777
+ _TorchRLGymWrapper,
2778
+ entry_point=entry_point,
2779
+ info_keys=info_keys,
2780
+ to_numpy=to_numpy,
2781
+ transform=transform,
2782
+ **kwargs,
2783
+ )
2784
+ return gym.register(
2785
+ id=id,
2786
+ entry_point=entry_point,
2787
+ reward_threshold=reward_threshold,
2788
+ nondeterministic=nondeterministic,
2789
+ max_episode_steps=max_episode_steps,
2790
+ order_enforce=order_enforce,
2791
+ autoreset=bool(autoreset),
2792
+ )
2793
+
2794
+ @implement_for("gym", None, "0.21", class_method=True)
2795
+ def _register_gym( # noqa: F811
2796
+ cls,
2797
+ id,
2798
+ entry_point: Callable | None = None,
2799
+ transform: Transform | None = None, # noqa: F821
2800
+ info_keys: list[NestedKey] | None = None,
2801
+ to_numpy: bool = False,
2802
+ reward_threshold: float | None = None,
2803
+ nondeterministic: bool = False,
2804
+ max_episode_steps: int | None = None,
2805
+ order_enforce: bool = True,
2806
+ autoreset: bool | None = None,
2807
+ disable_env_checker: bool = False,
2808
+ apply_api_compatibility: bool = False,
2809
+ **kwargs,
2810
+ ):
2811
+ import gym
2812
+ from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper
2813
+
2814
+ if order_enforce is not True:
2815
+ raise TypeError(
2816
+ cls._GYM_UNRECOGNIZED_KWARG.format("order_enforce", gym.__version__)
2817
+ )
2818
+ if disable_env_checker is not False:
2819
+ raise TypeError(
2820
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2821
+ "disable_env_checker", gym.__version__
2822
+ )
2823
+ )
2824
+ if autoreset is not None:
2825
+ raise TypeError(
2826
+ cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
2827
+ )
2828
+ if apply_api_compatibility is not False:
2829
+ raise TypeError(
2830
+ cls._GYM_UNRECOGNIZED_KWARG.format(
2831
+ "apply_api_compatibility", gym.__version__
2832
+ )
2833
+ )
2834
+ if entry_point is None:
2835
+ entry_point = cls
2836
+ entry_point = partial(
2837
+ _TorchRLGymWrapper,
2838
+ entry_point=entry_point,
2839
+ info_keys=info_keys,
2840
+ to_numpy=to_numpy,
2841
+ transform=transform,
2842
+ **kwargs,
2843
+ )
2844
+ return gym.register(
2845
+ id=id,
2846
+ entry_point=entry_point,
2847
+ reward_threshold=reward_threshold,
2848
+ nondeterministic=nondeterministic,
2849
+ max_episode_steps=max_episode_steps,
2850
+ )
2851
+
2852
+ @implement_for("gymnasium", None, "1.0.0", class_method=True)
2853
+ def _register_gym( # noqa: F811
2854
+ cls,
2855
+ id,
2856
+ entry_point: Callable | None = None,
2857
+ transform: Transform | None = None, # noqa: F821
2858
+ info_keys: list[NestedKey] | None = None,
2859
+ to_numpy: bool = False,
2860
+ reward_threshold: float | None = None,
2861
+ nondeterministic: bool = False,
2862
+ max_episode_steps: int | None = None,
2863
+ order_enforce: bool = True,
2864
+ autoreset: bool | None = None,
2865
+ disable_env_checker: bool = False,
2866
+ apply_api_compatibility: bool = False,
2867
+ **kwargs,
2868
+ ):
2869
+ import gymnasium
2870
+ from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper
2871
+
2872
+ if entry_point is None:
2873
+ entry_point = cls
2874
+
2875
+ entry_point = partial(
2876
+ _TorchRLGymnasiumWrapper,
2877
+ entry_point=entry_point,
2878
+ info_keys=info_keys,
2879
+ to_numpy=to_numpy,
2880
+ transform=transform,
2881
+ **kwargs,
2882
+ )
2883
+ return gymnasium.register(
2884
+ id=id,
2885
+ entry_point=entry_point,
2886
+ reward_threshold=reward_threshold,
2887
+ nondeterministic=nondeterministic,
2888
+ max_episode_steps=max_episode_steps,
2889
+ order_enforce=order_enforce,
2890
+ autoreset=bool(autoreset),
2891
+ disable_env_checker=disable_env_checker,
2892
+ apply_api_compatibility=apply_api_compatibility,
2893
+ )
2894
+
2895
    @implement_for("gymnasium", "1.1.0", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        # Version-specific branch (dispatched by ``implement_for``) for
        # gymnasium >= 1.1.0.
        import gymnasium
        from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper

        # ``autoreset`` was removed from gymnasium's registration API in 1.0;
        # reject it explicitly rather than silently ignoring it.
        if autoreset is not None:
            raise TypeError(
                f"the autoreset argument is deprecated in gymnasium>=1.0. Got autoreset={autoreset}"
            )
        if entry_point is None:
            entry_point = cls

        entry_point = partial(
            _TorchRLGymnasiumWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        # ``apply_api_compatibility`` is likewise unsupported in this version
        # range; only the default (False) is accepted.
        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gymnasium.__version__
                )
            )
        return gymnasium.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            disable_env_checker=disable_env_checker,
        )
2945
+
2946
+ def forward(self, *args, **kwargs):
2947
+ raise NotImplementedError(
2948
+ "EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
2949
+ "a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. "
2950
+ "Batched envs require constructors because environment instances may not always be serializable."
2951
+ )
2952
+
2953
    @abc.abstractmethod
    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        # Abstract hook: subclasses implement the actual environment step here.
        # The public ``step`` entry point wraps this with shape/done bookkeeping.
        raise NotImplementedError
2959
+
2960
    @abc.abstractmethod
    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        # Abstract hook: subclasses implement the actual environment reset here.
        # Must return a *new* tensordict (the public ``reset`` enforces this).
        raise NotImplementedError
2963
+
2964
    def reset(
        self,
        tensordict: TensorDictBase | None = None,
        **kwargs,
    ) -> TensorDictBase:
        """Resets the environment.

        As for step and _step, only the private method :obj:`_reset` should be overwritten by EnvBase subclasses.

        Args:
            tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation.
                In some cases, this input can also be used to pass argument to the reset function.
            kwargs (optional): other arguments to be passed to the native
                reset function.

        Returns:
            a tensordict (or the input tensordict, if any), modified in place with the resulting observations.

        .. note:: `reset` should not be overwritten by :class:`~torchrl.envs.EnvBase` subclasses. The method to
            modify is :meth:`~torchrl.envs.EnvBase._reset`.

        """
        # Validate the input batch size against the env's before doing any work.
        if tensordict is not None:
            self._assert_tensordict_shape(tensordict)

        # ``select_reset_only`` is consumed here and deliberately NOT forwarded
        # to ``_reset``.
        select_reset_only = kwargs.pop("select_reset_only", False)
        if select_reset_only and tensordict is not None:
            # When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
            # keys that are used by reset to optionally set the reset state (eg, the fen in chess). If that's the
            # case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
            # during the previous step).
            # Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
            # To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
            tensordict_reset = self._reset(
                tensordict.select(*self.reset_keys, strict=False), **kwargs
            )
        else:
            tensordict_reset = self._reset(tensordict, **kwargs)
        # We assume that this is done properly
        # if reset.device != self.device:
        #     reset = reset.to(self.device, non_blocking=True)
        if tensordict_reset is tensordict:
            raise RuntimeError(
                "EnvBase._reset should return outplace changes to the input "
                "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty()) "
                "inside _reset before writing new tensors onto this new instance."
            )
        if not isinstance(tensordict_reset, TensorDictBase):
            raise RuntimeError(
                f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
            )
        # Completes done entries, validates the post-reset done state, and
        # merges the reset output back into the input tensordict (if any).
        return self._reset_proc_data(tensordict, tensordict_reset)
3016
+
3017
+ def _reset_proc_data(self, tensordict, tensordict_reset):
3018
+ self._complete_done(self.full_done_spec, tensordict_reset)
3019
+ self._reset_check_done(tensordict, tensordict_reset)
3020
+ if tensordict is not None:
3021
+ return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
3022
+ return tensordict_reset
3023
+
3024
    def _reset_check_done(self, tensordict, tensordict_reset):
        """Checks the done status after reset.

        If _reset signals were passed, we check that the env is not done for these
        indices.

        We also check that the input tensordict contained ``"done"``s if the
        reset is partial and incomplete.

        """
        # we iterate over (reset_key, (done_key, truncated_key)) and check that all
        # values where reset was true now have a done set to False.
        # If no reset was present, all done and truncated must be False

        # Once we checked a root, we don't check its leaves - so keep track of the roots. Fortunately, we sort the done
        # keys in the done_keys_groups from root to leaf
        prefix_complete = set()
        for reset_key, done_key_group in zip(self.reset_keys, self.done_keys_groups):
            skip = False
            if isinstance(reset_key, tuple):
                # Nested reset key: skip it when one of its root prefixes has
                # already been validated.
                for i in range(len(reset_key) - 1):
                    if reset_key[:i] in prefix_complete:
                        skip = True
                        break
            if skip:
                continue
            # ``reset_value`` is the partial "_reset" mask, if the caller passed one.
            reset_value = (
                tensordict.get(reset_key, default=None)
                if tensordict is not None
                else None
            )
            prefix_complete.add(() if isinstance(reset_key, str) else reset_key[:-1])
            if reset_value is not None:
                # Partial reset: entries flagged for reset must not be done
                # anymore (unless the env explicitly allows done-after-reset).
                for done_key in done_key_group:
                    done_val = tensordict_reset.get(done_key)
                    if (
                        done_val.any()
                        and done_val[reset_value].any()
                        and not self._allow_done_after_reset
                    ):
                        raise RuntimeError(
                            f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                        )
                    if (
                        done_key not in tensordict.keys(True)
                        and done_val[~reset_value].any()
                    ):
                        warnings.warn(
                            f"A partial `'_reset'` key has been passed to `reset` ({reset_key}), "
                            f"but the corresponding done_key ({done_key}) wasn't present in the input "
                            f"tensordict. "
                            f"This is discouraged, since the input tensordict should contain "
                            f"all the data not being reset."
                        )
                        # we set the done val to tensordict, to make sure that
                        # _update_during_reset does not pad the value
                        tensordict.set(done_key, done_val)
            elif not self._allow_done_after_reset:
                # Full reset: no done entry may be True at all.
                for done_key in done_key_group:
                    if tensordict_reset.get(done_key).any():
                        raise RuntimeError(
                            f"The done entry '{done_key}' was (partially) True after a call to reset() in env {self}."
                        )
3087
+
3088
+ def numel(self) -> int:
3089
+ return prod(self.batch_size)
3090
+
3091
+ def set_seed(
3092
+ self, seed: int | None = None, static_seed: bool = False
3093
+ ) -> int | None:
3094
+ """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
3095
+
3096
+ Args:
3097
+ seed (int): seed to be set. The seed is set only locally in the environment. To handle the global seed,
3098
+ see :func:`~torch.manual_seed`.
3099
+ static_seed (bool, optional): if ``True``, the seed is not incremented.
3100
+ Defaults to False
3101
+
3102
+ Returns:
3103
+ integer representing the "next seed": i.e. the seed that should be
3104
+ used for another environment if created concomitantly to this environment.
3105
+
3106
+ """
3107
+ self._set_seed(seed)
3108
+ if seed is not None and not static_seed:
3109
+ new_seed = seed_generator(seed)
3110
+ seed = new_seed
3111
+ return seed
3112
+
3113
    @abc.abstractmethod
    def _set_seed(self, seed: int | None) -> None:
        # Abstract hook: subclasses apply the seed to their backend here.
        # The public ``set_seed`` wraps this and handles seed incrementation.
        raise NotImplementedError
3116
+
3117
    def set_state(self):
        # Placeholder: setting an explicit environment state is not supported
        # on the base class.
        raise NotImplementedError
3119
+
3120
+ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
3121
+ if (
3122
+ self.batch_locked or self.batch_size != ()
3123
+ ) and tensordict.batch_size != self.batch_size:
3124
+ raise RuntimeError(
3125
+ f"Expected a tensordict with shape==env.batch_size, "
3126
+ f"got {tensordict.batch_size} and {self.batch_size}"
3127
+ )
3128
+
3129
+ def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
3130
+ """Generates all possible actions from the action spec.
3131
+
3132
+ This only works in environments with fully discrete actions.
3133
+
3134
+ Args:
3135
+ tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
3136
+ is called with this tensordict.
3137
+
3138
+ Returns:
3139
+ a tensordict object with the "action" entry updated with a batch of
3140
+ all possible actions. The actions are stacked together in the
3141
+ leading dimension.
3142
+ """
3143
+ if tensordict is not None:
3144
+ self.reset(tensordict)
3145
+
3146
+ return self.full_action_spec.enumerate(use_mask=True)
3147
+
3148
+ def rand_action(self, tensordict: TensorDictBase | None = None):
3149
+ """Performs a random action given the action_spec attribute.
3150
+
3151
+ Args:
3152
+ tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.
3153
+
3154
+ Returns:
3155
+ a tensordict object with the "action" entry updated with a random
3156
+ sample from the action-spec.
3157
+
3158
+ """
3159
+ shape = torch.Size([])
3160
+ if not self.batch_locked:
3161
+ if not self.batch_size and tensordict is not None:
3162
+ # if we can't infer the batch-size from the env, take it from tensordict
3163
+ shape = tensordict.shape
3164
+ elif not self.batch_size:
3165
+ # if tensordict wasn't provided, we assume empty batch size
3166
+ shape = torch.Size([])
3167
+ elif tensordict.shape != self.batch_size:
3168
+ # if tensordict is not None and the env has a batch size, their shape must match
3169
+ raise RuntimeError(
3170
+ "The input tensordict and the env have a different batch size: "
3171
+ f"env.batch_size={self.batch_size} and tensordict.batch_size={tensordict.shape}. "
3172
+ f"Non batch-locked environment require the env batch-size to be either empty or to"
3173
+ f" match the tensordict one."
3174
+ )
3175
+ # We generate the action from the full_action_spec
3176
+ r = self.input_spec["full_action_spec"].rand(shape)
3177
+ if tensordict is None:
3178
+ return r
3179
+ tensordict.update(r)
3180
+ return tensordict
3181
+
3182
+ def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
3183
+ """Performs a random step in the environment given the action_spec attribute.
3184
+
3185
+ Args:
3186
+ tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
3187
+
3188
+ Returns:
3189
+ a tensordict object with the new observation after a random step in the environment. The action will
3190
+ be stored with the "action" key.
3191
+
3192
+ """
3193
+ tensordict = self.rand_action(tensordict)
3194
+ return self.step(tensordict)
3195
+
3196
+ @property
3197
+ def specs(self) -> Composite:
3198
+ """Returns a Composite container where all the environment are present.
3199
+
3200
+ This feature allows one to create an environment, retrieve all of the specs in a single data container and then
3201
+ erase the environment from the workspace.
3202
+
3203
+ """
3204
+ return Composite(
3205
+ output_spec=self.output_spec,
3206
+ input_spec=self.input_spec,
3207
+ shape=self.batch_size,
3208
+ )
3209
+
3210
    @property
    @_cache_value
    def _has_dynamic_specs(self) -> bool:
        # Cached property: delegates to the module-level ``_has_dynamic_specs``
        # helper (same name as this property) applied to the env's full specs.
        return _has_dynamic_specs(self.specs)
3214
+
3215
    def rollout(
        self,
        max_steps: int,
        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
        callback: Callable[[TensorDictBase, ...], Any] | None = None,
        *,
        auto_reset: bool = True,
        auto_cast_to_device: bool = False,
        break_when_any_done: bool | None = None,
        break_when_all_done: bool | None = None,
        return_contiguous: bool | None = False,
        tensordict: TensorDictBase | None = None,
        set_truncated: bool = False,
        out=None,
        trust_policy: bool = False,
        storing_device: DEVICE_TYPING | None = None,
    ) -> TensorDictBase:
        """Executes a rollout in the environment.

        The function will return as soon as any of the contained environments
        reaches any of the done states.

        Args:
            max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
                the environment reaches a done state before max_steps have been executed.
            policy (callable, optional): callable to be called to compute the desired action.
                If no policy is provided, actions will be called using :obj:`env.rand_step()`.
                The policy can be any callable that reads either a tensordict or
                the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
                Defaults to `None`.
            callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
                TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
                responsibility to save any result within the callback call if data needs to be carried over beyond
                the call to ``rollout``.

        Keyword Args:
            auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
                rollout. If ``False``, then the rollout will continue from a previous state, which requires the
                ``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
            auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
                policy device before the policy is used. Default is ``False``.
            break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
                done states. If ``False``, then the done environments are reset automatically. Default is ``True``.

                .. seealso:: The :ref:`Partial resets <ref_partial_resets>` of the documentation gives more
                    information about partial resets.

            break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
                of the done states. If ``False``, break if at least one environment reaches any of the done states.
                Default is ``False``.

                .. seealso:: The :ref:`Partial steps <ref_partial_steps>` of the documentation gives more
                    information about partial resets.

            return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if
                the env does not have dynamic specs, otherwise `False`.

                .. note:: NOTE(review): the signature default is ``False``; the spec-dependent
                    default described above only takes effect when ``return_contiguous=None``
                    is passed explicitly — confirm which default is intended.

            tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
                tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
                environment in those dimensions (if needed).
                This normally should not occur if ``tensordict`` is the output of a reset, but can occur
                if ``tensordict`` is the last step of a previous rollout.
                A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed
                to the ``reset`` method, such as a batch-size or a device for stateless environments.
            set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to
                ``True`` after completion of the rollout. If no ``"truncated"`` is found within the
                ``done_spec``, an exception is raised.
                Truncated keys can be set through ``env.add_truncated_keys``.
                Defaults to ``False``.
            trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
                assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
                and ``False`` otherwise.
            storing_device (Device, optional): if provided, the tensordict will be stored on this device.
                Defaults to ``None``.

        Returns:
            TensorDict object containing the resulting trajectory.

        The data returned will be marked with a "time" dimension name for the last
        dimension of the tensordict (at the ``env.ndim`` index).

        ``rollout`` is quite handy to display what the data structure of the
        environment looks like.

        Examples:
            >>> # Using rollout without a policy
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> from torchrl.envs.transforms import TransformedEnv, StepCounter
            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
            >>> rollout = env.rollout(max_steps=1000)
            >>> print(rollout)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                            truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([20]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                    truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([20]),
                device=cpu,
                is_shared=False)
            >>> print(rollout.names)
            ['time']
            >>> # with envs that contain more dimensions
            >>> from torchrl.envs import SerialEnv
            >>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
            >>> rollout = env.rollout(max_steps=1000)
            >>> print(rollout)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                            truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([3, 20]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([3, 20]),
                device=cpu,
                is_shared=False)
            >>> print(rollout.names)
            [None, 'time']

        Using a policy (a regular :class:`~torch.nn.Module` or a :class:`~tensordict.nn.TensorDictModule`)
        is also easy:

        Examples:
            >>> from torch import nn
            >>> env = GymEnv("CartPole-v1", categorical_action_encoding=True)
            >>> class ArgMaxModule(nn.Module):
            ...     def forward(self, values):
            ...         return values.argmax(-1)
            >>> n_obs = env.observation_spec["observation"].shape[-1]
            >>> n_act = env.action_spec.n
            >>> # A deterministic policy
            >>> policy = nn.Sequential(
            ...     nn.Linear(n_obs, n_act),
            ...     ArgMaxModule())
            >>> env.rollout(max_steps=10, policy=policy)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([10]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([10]),
                device=cpu,
                is_shared=False)
            >>> # Under the hood, rollout will wrap the policy in a TensorDictModule
            >>> # To speed things up we can do that ourselves
            >>> from tensordict.nn import TensorDictModule
            >>> policy = TensorDictModule(policy, in_keys=list(env.observation_spec.keys()), out_keys=["action"])
            >>> env.rollout(max_steps=10, policy=policy)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([10]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([10]),
                device=cpu,
                is_shared=False)


        In some instances, contiguous tensordict cannot be obtained because
        they cannot be stacked. This can happen when the data returned at
        each step may have a different shape, or when different environments
        are executed together. In that case, ``return_contiguous=False``
        will cause the returned tensordict to be a lazy stack of tensordicts:

        Examples of non-contiguous rollout:
            >>> rollout = env.rollout(4, return_contiguous=False)
            >>> print(rollout)
            LazyStackedTensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: LazyStackedTensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                            truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([3, 4]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([3, 4]),
                device=cpu,
                is_shared=False)
            >>> print(rollout.names)
            [None, 'time']

        Rollouts can be used in a loop to emulate data collection.
        To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
        :func:`~torchrl.envs.utils.step_mdp` on it.

        Examples of data collection rollouts:
            >>> from torchrl.envs import GymEnv, step_mdp
            >>> env = GymEnv("CartPole-v1")
            >>> epochs = 10
            >>> input_td = env.reset()
            >>> for i in range(epochs):
            ...     rollout_td = env.rollout(
            ...         max_steps=100,
            ...         break_when_any_done=False,
            ...         auto_reset=False,
            ...         tensordict=input_td,
            ...     )
            ...     input_td = step_mdp(
            ...         rollout_td[..., -1],
            ...     )

        """
        # Resolve the break_when_* flags: break_when_any_done defaults to True
        # unless break_when_all_done was requested; both True is ambiguous.
        if break_when_any_done is None:  # True by default
            if break_when_all_done:  # all overrides
                break_when_any_done = False
            else:
                break_when_any_done = True
        if break_when_all_done is None:
            # There is no case where break_when_all_done is True by default
            break_when_all_done = False
        if break_when_all_done and break_when_any_done:
            raise TypeError(
                "Cannot have both break_when_all_done and break_when_any_done True at the same time."
            )
        # Only an explicit ``None`` triggers the spec-dependent contiguity default.
        if return_contiguous is None:
            return_contiguous = not self._has_dynamic_specs
        if policy is not None:
            policy = _make_compatible_policy(
                policy,
                self.observation_spec,
                env=self,
                fast_wrap=True,
                trust_policy=trust_policy,
            )
            if auto_cast_to_device:
                # Infer the policy device from its first parameter, if any.
                try:
                    policy_device = next(policy.parameters()).device
                except (StopIteration, AttributeError):
                    policy_device = None
            else:
                policy_device = None
        else:
            # No policy: fall back to random actions.
            policy = self.rand_action
            policy_device = None

        env_device = self.device

        if auto_reset:
            tensordict = self.reset(tensordict)
        elif tensordict is None:
            raise RuntimeError("tensordict must be provided when auto_reset is False")
        else:
            tensordict = self.maybe_reset(tensordict)

        kwargs = {
            "tensordict": tensordict,
            "auto_cast_to_device": auto_cast_to_device,
            "max_steps": max_steps,
            "policy": policy,
            "policy_device": policy_device,
            "env_device": env_device,
            "storing_device": None
            if storing_device is None
            else torch.device(storing_device),
            "callback": callback,
        }
        # Dispatch to the early-stopping or non-stop rollout loop.
        if break_when_any_done or break_when_all_done:
            tensordicts = self._rollout_stop_early(
                break_when_all_done=break_when_all_done,
                break_when_any_done=break_when_any_done,
                **kwargs,
            )
        else:
            tensordicts = self._rollout_nonstop(**kwargs)
        batch_size = self.batch_size if tensordict is None else tensordict.batch_size
        if return_contiguous:
            try:
                out_td = torch.stack(tensordicts, len(batch_size), out=out)
            except RuntimeError as err:
                # Translate the two known stacking failures into actionable errors.
                if (
                    re.match(
                        "The shapes of the tensors to stack is incompatible", str(err)
                    )
                    and self._has_dynamic_specs
                ):
                    raise RuntimeError(
                        "The environment specs are dynamic. Call rollout with return_contiguous=False."
                    )
                if re.match(
                    "The sets of keys in the tensordicts to stack are exclusive",
                    str(err),
                ):
                    for reward_key in self.reward_keys:
                        if any(reward_key in td for td in tensordicts):
                            raise RuntimeError(
                                "The reward key was present in the root tensordict of at least one of the tensordicts to stack. "
                                "The likely cause is that your environment returns a reward during a call to `reset`, which is not allowed. "
                                "To fix this, you should return the reward in the `step` method but not in during `reset`. If you need a reward "
                                "to be returned during `reset`, submit an issue on github."
                            )
                raise
        else:
            out_td = LazyStackedTensorDict.maybe_dense_stack(
                tensordicts, len(batch_size), out=out
            )
        if set_truncated:
            # Mark the last step of the rollout as truncated (and done).
            found_truncated = False
            for key in self.done_keys:
                if _ends_with(key, "truncated"):
                    val = out_td.get(("next", key))
                    done = out_td.get(("next", _replace_last(key, "done")))
                    val[(slice(None),) * (out_td.ndim - 1) + (-1,)] = True
                    out_td.set(("next", key), val)
                    out_td.set(("next", _replace_last(key, "done")), val | done)
                    found_truncated = True
            if not found_truncated:
                raise RuntimeError(
                    "set_truncated was set to True but no truncated key could be found. "
                    "Make sure a 'truncated' entry was set in the environment "
                    "full_done_keys using `env.add_truncated_keys()`."
                )

        # Name the trailing dimension "time" as documented above.
        out_td.refine_names(..., "time")
        return out_td
3583
+
3584
+ @_maybe_unlock
3585
+ def add_truncated_keys(self) -> EnvBase:
3586
+ """Adds truncated keys to the environment."""
3587
+ i = 0
3588
+ for key in self.done_keys:
3589
+ i += 1
3590
+ truncated_key = _replace_last(key, "truncated")
3591
+ self.full_done_spec[truncated_key] = self.full_done_spec[key].clone()
3592
+ if i == 0:
3593
+ raise KeyError(f"Couldn't find done keys. done_spec={self.full_done_specs}")
3594
+
3595
+ return self
3596
+
3597
    def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Advances the environment state by one step using the provided `next_tensordict`.

        This method updates the environment's state by transitioning from the current
        state to the next, as defined by the `next_tensordict`. The resulting tensordict
        includes updated observations and any other relevant state information, with
        keys managed according to the environment's specifications.

        Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently
        handle the transition of state, observation, action, reward, and done keys. The
        :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
        exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
        is created with `exclude_action=False`, meaning that action keys are retained in
        the root tensordict.

        Args:
            next_tensordict (TensorDictBase): A tensordict containing the state of the
                environment at the next time step. This tensordict should include keys
                for observations, actions, rewards, and done flags, as defined by the
                environment's specifications.

        Returns:
            TensorDictBase: A new tensordict representing the environment state after
            advancing by one step.

        .. note:: The method ensures that the environment's key specifications are validated
            against the provided `next_tensordict`, issuing warnings if discrepancies
            are found.

        .. note:: This method is designed to work efficiently with environments that have
            consistent key specifications, leveraging the `_StepMDP` class to minimize
            overhead.

        Example:
            >>> from torchrl.envs import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> data = env.reset()
            >>> for i in range(10):
            ...     # compute action
            ...     env.rand_action(data)
            ...     # Perform action
            ...     next_data = env.step(data)
            ...     data = env.step_mdp(next_data)
        """
        return self._step_mdp(next_tensordict)
3642
+
3643
+ @property
3644
+ @_cache_value
3645
+ def _step_mdp(self) -> Callable[[TensorDictBase], TensorDictBase]:
3646
+ return _StepMDP(self, exclude_action=False)
3647
+
3648
    def _rollout_stop_early(
        self,
        *,
        break_when_any_done,
        break_when_all_done,
        tensordict,
        auto_cast_to_device,
        max_steps,
        policy,
        policy_device,
        env_device,
        storing_device,
        callback,
    ):
        """Rollout loop used when the iteration must stop on done states.

        Runs at most ``max_steps`` policy/step iterations. With
        ``break_when_any_done``, the loop exits as soon as any done flag is
        raised; with ``break_when_all_done``, finished batch entries stop being
        updated (their last valid step is carried over) and the loop exits once
        every entry is done. Returns the list of per-step tensordicts, moved to
        ``storing_device`` when it differs from the data's device.
        """
        # Get the sync func
        if auto_cast_to_device:
            sync_func = _get_sync_func(policy_device, env_device)
        tensordicts = []
        # `partial_steps` stays the sentinel `True` until some entry finishes;
        # afterwards it is a boolean mask of the entries still running.
        partial_steps = True
        for i in range(max_steps):
            if auto_cast_to_device:
                if policy_device is not None:
                    tensordict = tensordict.to(policy_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict.clear_device_()
            # In case policy(..) does not modify in-place - no-op for TensorDict and related
            tensordict.update(policy(tensordict))
            if auto_cast_to_device:
                if env_device is not None:
                    tensordict = tensordict.to(env_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict.clear_device_()
            tensordict = self.step(tensordict)
            if storing_device is None or tensordict.device == storing_device:
                td_append = tensordict.copy()
            else:
                td_append = tensordict.to(storing_device)
            if break_when_all_done:
                if partial_steps is not True and not partial_steps.all():
                    # At least one step is partial
                    td_append.pop("_step", None)
                    # Entries that already finished keep the previous step's data
                    # instead of the freshly computed one.
                    td_append = torch.where(
                        partial_steps.view(td_append.shape), td_append, tensordicts[-1]
                    )

            tensordicts.append(td_append)

            if i == max_steps - 1:
                # we don't truncate as one could potentially continue the run
                break
            tensordict = self._step_mdp(tensordict)

            if break_when_any_done:
                # done and truncated are in done_keys
                # We read if any key is done.
                any_done = _terminated_or_truncated(
                    tensordict,
                    full_done_spec=self.output_spec["full_done_spec"],
                    key=None,
                )
                if any_done:
                    break
            else:
                # Write the '_step' entry, indicating which step is to be undertaken
                _terminated_or_truncated(
                    tensordict,
                    full_done_spec=self.output_spec["full_done_spec"],
                    key="_neg_step",
                    write_full_false=False,
                )
                # This is what differentiates _step and _reset: we need to flip _step False -> True
                partial_step_curr = tensordict.pop("_neg_step", None)
                if partial_step_curr is not None:
                    partial_step_curr = ~partial_step_curr
                    # Accumulate: an entry that finished once stays finished.
                    partial_steps = partial_steps & partial_step_curr
                if partial_steps is not True:
                    if not partial_steps.any():
                        # Every entry is done: stop the rollout.
                        break
                    # Write the final _step entry
                    tensordict.set("_step", partial_steps)

            if callback is not None:
                callback(self, tensordict)
        return tensordicts
3734
+
3735
    def _rollout_nonstop(
        self,
        *,
        tensordict,
        auto_cast_to_device,
        max_steps,
        policy,
        policy_device,
        env_device,
        storing_device,
        callback,
    ):
        """Rollout loop of exactly ``max_steps`` steps that never stops on done flags.

        Uses :meth:`step_and_maybe_reset` so finished sub-environments are reset
        in-flight; on the very last iteration a plain :meth:`step` is used since
        no further step will consume the reset state. Returns the list of
        per-step tensordicts, moved to ``storing_device`` when needed.
        """
        if auto_cast_to_device:
            sync_func = _get_sync_func(policy_device, env_device)
        tensordicts = []
        # `tensordict_` is the root of the *next* step (post step_mdp/reset);
        # `tensordict` holds the stepped data that gets stored.
        tensordict_ = tensordict
        for i in range(max_steps):
            if auto_cast_to_device:
                if policy_device is not None:
                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict_.clear_device_()
            # In case policy(..) does not modify in-place - no-op for TensorDict and related
            tensordict_.update(policy(tensordict_))
            if auto_cast_to_device:
                if env_device is not None:
                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict_.clear_device_()
            if i == max_steps - 1:
                # Last iteration: no reset needed, the loop ends here.
                tensordict = self.step(tensordict_)
            else:
                tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
            if storing_device is None or tensordict.device == storing_device:
                tensordicts.append(tensordict)
            else:
                tensordicts.append(tensordict.to(storing_device))
            if i == max_steps - 1:
                # we don't truncate as one could potentially continue the run
                break
            if callback is not None:
                callback(self, tensordict)

        return tensordicts
3781
+
3782
+ def step_and_maybe_reset(
3783
+ self, tensordict: TensorDictBase
3784
+ ) -> tuple[TensorDictBase, TensorDictBase]:
3785
+ """Runs a step in the environment and (partially) resets it if needed.
3786
+
3787
+ Args:
3788
+ tensordict (TensorDictBase): an input data structure for the :meth:`step`
3789
+ method.
3790
+
3791
+ This method allows to easily code non-stopping rollout functions.
3792
+
3793
+ Examples:
3794
+ >>> from torchrl.envs import ParallelEnv, GymEnv
3795
+ >>> def rollout(env, n):
3796
+ ... data_ = env.reset()
3797
+ ... result = []
3798
+ ... for i in range(n):
3799
+ ... data, data_ = env.step_and_maybe_reset(data_)
3800
+ ... result.append(data)
3801
+ ... return torch.stack(result)
3802
+ >>> env = ParallelEnv(2, lambda: GymEnv("CartPole-v1"))
3803
+ >>> print(rollout(env, 2))
3804
+ TensorDict(
3805
+ fields={
3806
+ done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
3807
+ next: TensorDict(
3808
+ fields={
3809
+ done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
3810
+ observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
3811
+ reward: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
3812
+ terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
3813
+ truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
3814
+ batch_size=torch.Size([2, 2]),
3815
+ device=cpu,
3816
+ is_shared=False),
3817
+ observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
3818
+ terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
3819
+ truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
3820
+ batch_size=torch.Size([2, 2]),
3821
+ device=cpu,
3822
+ is_shared=False)
3823
+ """
3824
+ if tensordict.device != self.device:
3825
+ tensordict = tensordict.to(self.device)
3826
+ tensordict = self.step(tensordict)
3827
+ # done and truncated are in done_keys
3828
+ # We read if any key is done.
3829
+ tensordict_ = self._step_mdp(tensordict)
3830
+ # if self._post_step_mdp_hooks is not None:
3831
+ # tensordict_ = self._post_step_mdp_hooks(tensordict_)
3832
+ tensordict_ = self.maybe_reset(tensordict_)
3833
+ return tensordict, tensordict_
3834
+
3835
+ # _post_step_mdp_hooks: Callable[[TensorDictBase], TensorDictBase] | None = None
3836
+
3837
+ @property
3838
+ @_cache_value
3839
+ def _simple_done(self):
3840
+ key_set = set(self.full_done_spec.keys())
3841
+
3842
+ _simple_done = "done" in key_set and "terminated" in key_set
3843
+ return _simple_done
3844
+
3845
+ def any_done(self, tensordict: TensorDictBase) -> bool:
3846
+ """Checks if the tensordict is in a "done" state (or if an element of the batch is).
3847
+
3848
+ Writes the result under the `"_reset"` entry.
3849
+
3850
+ Returns: a bool indicating whether there is an element in the tensordict that is marked
3851
+ as done.
3852
+
3853
+ .. note:: The tensordict passed should be a `"next"` tensordict or equivalent -- i.e., it should not
3854
+ contain a `"next"` value.
3855
+
3856
+ """
3857
+ if self._simple_done:
3858
+ done = tensordict._get_str("done", default=None)
3859
+ if done is not None:
3860
+ any_done = done.any()
3861
+ else:
3862
+ any_done = False
3863
+ if any_done:
3864
+ tensordict._set_str(
3865
+ "_reset",
3866
+ done.clone(),
3867
+ validated=True,
3868
+ inplace=False,
3869
+ non_blocking=False,
3870
+ )
3871
+ else:
3872
+ any_done = _terminated_or_truncated(
3873
+ tensordict,
3874
+ full_done_spec=self.output_spec["full_done_spec"],
3875
+ key="_reset",
3876
+ )
3877
+ return any_done
3878
+
3879
+ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
3880
+ """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.
3881
+
3882
+ Args:
3883
+ tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.
3884
+
3885
+ Returns:
3886
+ A tensordict that is identical to the input where the environment was
3887
+ not reset and contains the new reset data where the environment was reset.
3888
+
3889
+ """
3890
+ any_done = self.any_done(tensordict)
3891
+ if any_done:
3892
+ tensordict = self.reset(tensordict, select_reset_only=True)
3893
+ return tensordict
3894
+
3895
+ def empty_cache(self):
3896
+ """Erases all the cached values.
3897
+
3898
+ For regular envs, the key lists (reward, done etc) are cached, but in some cases
3899
+ they may change during the execution of the code (eg, when adding a transform).
3900
+
3901
+ """
3902
+ self._cache.clear()
3903
+
3904
+ @property
3905
+ @_cache_value
3906
+ def reset_keys(self) -> list[NestedKey]:
3907
+ """Returns a list of reset keys.
3908
+
3909
+ Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
3910
+ settings. They are structured as ``(*prefix, "_reset")`` where ``prefix`` is
3911
+ a (possibly empty) tuple of strings pointing to a tensordict location
3912
+ where a done state can be found.
3913
+
3914
+ Keys are sorted by depth in the data tree.
3915
+ """
3916
+ reset_keys = sorted(
3917
+ (
3918
+ _replace_last(done_key, "_reset")
3919
+ for (done_key, *_) in self.done_keys_groups
3920
+ ),
3921
+ key=_repr_by_depth,
3922
+ )
3923
+ return reset_keys
3924
+
3925
    @property
    def _filtered_reset_keys(self):
        """Returns only the effective reset keys, discarding nested resets if they're not being used."""
        reset_keys = self.reset_keys
        result = []

        def _root(key):
            # A plain-string key lives at the root of the tensordict.
            if isinstance(key, str):
                return ()
            return key[:-1]

        roots = []
        for reset_key in reset_keys:
            cur_root = _root(reset_key)
            for root in roots:
                # Discard this reset key when it is nested under a root that is
                # already kept. NOTE: reset_keys is sorted by depth, so kept
                # (shallower) roots are seen before their nested descendants.
                if cur_root[: len(root)] == root:
                    break
            else:
                # for/else: no kept root is a prefix of cur_root -> keep it.
                roots.append(cur_root)
                result.append(reset_key)
        return result
3946
+
3947
+ @property
3948
+ @_cache_value
3949
+ def done_keys_groups(self):
3950
+ """A list of done keys, grouped as the reset keys.
3951
+
3952
+ This is a list of lists. The outer list has the length of reset keys, the
3953
+ inner lists contain the done keys (eg, done and truncated) that can
3954
+ be read to determine a reset when it is absent.
3955
+ """
3956
+ # done keys, sorted as reset keys
3957
+ done_keys_group = []
3958
+ roots = set()
3959
+ fds = self.full_done_spec
3960
+ for done_key in self.done_keys:
3961
+ root_name = done_key[:-1] if isinstance(done_key, tuple) else ()
3962
+ root = fds[root_name] if root_name else fds
3963
+ n = len(roots)
3964
+ roots.add(root_name)
3965
+ if len(roots) - n:
3966
+ done_keys_group.append(
3967
+ [
3968
+ unravel_key(root_name + (key,))
3969
+ for key in root.keys(include_nested=False, leaves_only=True)
3970
+ ]
3971
+ )
3972
+ return done_keys_group
3973
+
3974
+ def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
3975
+ for key in tensordict.keys():
3976
+ if key.rfind("observation") >= 0:
3977
+ yield key
3978
+
3979
    def close(self, *, raise_if_closed: bool = True):
        # Base implementation only flags the env as closed; subclasses override
        # this to release external resources. `raise_if_closed` is accepted for
        # API compatibility with overrides but is unused here.
        self.is_closed = True
3981
+
3982
+ def __del__(self):
3983
+ # if del occurs before env has been set up, we don't want a recursion
3984
+ # error
3985
+ if "is_closed" in self.__dict__ and not self.is_closed:
3986
+ try:
3987
+ self.close()
3988
+ except Exception:
3989
+ # a TypeError will typically be raised if the env is deleted when the program ends.
3990
+ # In the future, insignificant changes to the close method may change the error type.
3991
+ # We excplicitely assume that any error raised during closure in
3992
+ # __del__ will not affect the program.
3993
+ pass
3994
+
3995
+ @_maybe_unlock
3996
+ def to(self, device: DEVICE_TYPING) -> EnvBase:
3997
+ device = _make_ordinal_device(torch.device(device))
3998
+ if device == self.device:
3999
+ return self
4000
+ self.__dict__["_input_spec"] = self.input_spec.to(device)
4001
+ self.__dict__["_output_spec"] = self.output_spec.to(device)
4002
+ self._device = device
4003
+ return super().to(device)
4004
+
4005
    def fake_tensordict(self) -> TensorDictBase:
        """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
        state_spec = self.state_spec
        observation_spec = self.observation_spec
        action_spec = self.input_spec["full_action_spec"]
        # instantiates reward_spec if needed
        _ = self.full_reward_spec
        reward_spec = self.output_spec["full_reward_spec"]
        full_done_spec = self.output_spec["full_done_spec"]

        # Zero-valued samples drawn from each spec.
        fake_obs = observation_spec.zero()
        fake_reward = reward_spec.zero()
        fake_done = full_done_spec.zero()
        fake_state = state_spec.zero()
        fake_action = action_spec.zero()

        # NOTE(review): when the action contains lazy-stacked values the update
        # direction is flipped — presumably the lazy stack must be the update's
        # base to be preserved; confirm against LazyStackedTensorDict semantics.
        if any(
            isinstance(val, LazyStackedTensorDict) for val in fake_action.values(True)
        ):
            fake_input = fake_action.update(fake_state)
        else:
            fake_input = fake_state.update(fake_action)

        # the input and output key may match, but the output prevails
        # Hence we generate the input, and override using the output
        fake_in_out = fake_input.update(fake_obs)

        # The "next" entry aggregates observation, reward and done data.
        next_output = fake_obs.clone()
        next_output.update(fake_reward)
        next_output.update(fake_done)
        fake_in_out.update(fake_done.clone())
        if "next" not in fake_in_out.keys():
            fake_in_out.set("next", next_output)
        else:
            fake_in_out.get("next").update(next_output)

        fake_in_out.batch_size = self.batch_size
        fake_in_out = fake_in_out.to(self.device)
        return fake_in_out
4044
+
4045
+
4046
class _EnvWrapper(EnvBase):
    """Abstract environment wrapper class.

    Unlike EnvBase, _EnvWrapper comes with a :obj:`_build_env` private method that will be called upon instantiation.
    Interfaces with other libraries should be coded using _EnvWrapper.

    It is possible to directly query attributes from the nested environment if its name does not conflict with
    an attribute of the wrapper:
        >>> env = SomeWrapper(...)
        >>> custom_attribute0 = env._env.custom_attribute
        >>> custom_attribute1 = env.custom_attribute
        >>> assert custom_attribute0 is custom_attribute1  # should return True

    """

    git_url: str = ""
    # NOTE: class-level mutable default, shared across subclasses unless overridden.
    available_envs: dict[str, Any] = {}
    libname: str = ""

    def __init__(
        self,
        *args,
        device: DEVICE_TYPING = None,
        batch_size: torch.Size | None = None,
        allow_done_after_reset: bool = False,
        spec_locked: bool = True,
        **kwargs,
    ):
        """Builds the wrapped environment.

        All backend-specific options must be passed as keyword arguments; the
        remaining kwargs are validated by ``_check_kwargs`` and forwarded to
        ``_build_env``.
        """
        super().__init__(
            device=device,
            batch_size=batch_size,
            allow_done_after_reset=allow_done_after_reset,
            spec_locked=spec_locked,
        )
        if len(args):
            raise ValueError(
                "`_EnvWrapper.__init__` received a non-empty args list of arguments. "
                "Make sure only keywords arguments are used when calling `super().__init__`."
            )

        frame_skip = kwargs.pop("frame_skip", 1)
        if not isinstance(frame_skip, int):
            raise ValueError(f"frame_skip must be an integer, got {frame_skip}")
        self.frame_skip = frame_skip
        # this value can be changed if frame_skip is passed during env construction
        self.wrapper_frame_skip = frame_skip

        self._constructor_kwargs = kwargs
        self._check_kwargs(kwargs)
        self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True)
        self._env = self._build_env(**kwargs)  # writes the self._env attribute
        self._make_specs(self._env)  # writes the self._env attribute
        self.is_closed = False
        self._init_env()  # runs all the steps to have a ready-to-use env

    def _sync_device(self):
        """Resolves (and caches) the device-synchronization callable for ``self.device``."""
        sync_func = self.__dict__.get("_sync_device_val")
        if sync_func is None:
            device = self.device
            if device.type != "cuda":
                if torch.cuda.is_available():
                    self._sync_device_val = torch.cuda.synchronize
                elif torch.backends.mps.is_available():
                    # Bug fix: this branch previously assigned
                    # torch.cuda.synchronize, which fails on MPS-only machines.
                    self._sync_device_val = torch.mps.synchronize
                elif device.type == "cpu":
                    self._sync_device_val = _do_nothing
                else:
                    self._sync_device_val = _do_nothing
            # NOTE(review): for a "cuda" device no sync func is cached and the
            # method returns itself — callers get a callable either way, but
            # confirm the intended behavior for CUDA devices.
            return self._sync_device
        return sync_func

    @abc.abstractmethod
    def _check_kwargs(self, kwargs: dict):
        """Validates the constructor kwargs before the backend env is built."""
        raise NotImplementedError

    def __getattr__(self, attr: str) -> Any:
        # Fall back to the wrapped environment for unknown attributes, while
        # keeping normal attribute errors for dunder names.
        if attr in self.__dir__():
            return self.__getattribute__(
                attr
            )  # make sure that appropriate exceptions are raised

        elif attr.startswith("__"):
            raise AttributeError(
                "passing built-in private methods is "
                f"not permitted with type {type(self)}. "
                f"Got attribute {attr}."
            )

        elif "_env" in self.__dir__():
            env = self.__getattribute__("_env")
            return getattr(env, attr)
        # Let the parent's __getattr__ raise if it defines one; its return
        # value (if any) is deliberately discarded before the final error.
        super().__getattr__(attr)

        raise AttributeError(
            f"The env wasn't set in {self.__class__.__name__}, cannot access {attr}"
        )

    @abc.abstractmethod
    def _init_env(self) -> int | None:
        """Runs all the necessary steps such that the environment is ready to use.

        This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
        is reset (if needed). For instance, DMControl envs require the env to be reset before being used, but Gym envs
        don't.

        Returns:
            the resulting seed

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
        """Creates an environment from the target library and stores it with the `_env` attribute.

        When overwritten, this function should pass all the required kwargs to the env instantiation method.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
        """Populates the TorchRL specs from the wrapped backend environment."""
        raise NotImplementedError

    def close(self, *, raise_if_closed: bool = True) -> None:
        """Closes the contained environment if possible."""
        self.is_closed = True
        try:
            self._env.close()
        except AttributeError:
            # Best effort: the backend env may not expose a close() method.
            pass
4177
+
4178
+
4179
def make_tensordict(
    env: _EnvWrapper,
    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
) -> TensorDictBase:
    """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.

    Args:
        env (_EnvWrapper): environment defining the observation, action and reward space;
        policy (Callable, optional): policy corresponding to the environment.

    """
    with torch.no_grad():
        rollout_td = env.reset()
        if policy is None:
            # Without a policy, sample a random action from the spec.
            rollout_td.set("action", env.action_spec.rand(), inplace=False)
        else:
            rollout_td.update(policy(rollout_td))
        rollout_td = env.step(rollout_td)
        return rollout_td.zero_()
4198
+
4199
+
4200
def _get_sync_func(policy_device, env_device):
    """Picks the synchronization callable matching the policy/env device pair."""
    if torch.cuda.is_available():
        policy_on_cuda = policy_device is not None and policy_device.type == "cuda"
        env_on_cuda = env_device is not None and env_device.type == "cuda"
        # Look for a specific device
        if policy_on_cuda:
            if env_device is None or env_on_cuda:
                return torch.cuda.synchronize
            return partial(torch.cuda.synchronize, device=policy_device)
        if env_on_cuda:
            if policy_device is None:
                return torch.cuda.synchronize
            return partial(torch.cuda.synchronize, device=env_device)
        return torch.cuda.synchronize
    if torch.backends.mps.is_available():
        return torch.mps.synchronize
    return _do_nothing
4215
+
4216
+
4217
+ def _do_nothing():
4218
+ return
4219
+
4220
+
4221
def _has_dynamic_specs(spec: Composite):
    """Returns ``True`` when any leaf spec carries an unknown (-1) dimension."""
    from tensordict.base import _NESTED_TENSORS_AS_LISTS

    leaves = spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
    return any(any(dim == -1 for dim in leaf.shape) for leaf in leaves)
4228
+
4229
+
4230
def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
    """Infers a spec for ``leaf`` and writes it into ``stack[name]``.

    Non-tensor leaves map to :class:`NonTensor`; tensors map to
    :class:`Unbounded` with matching shape, device and dtype. When
    ``leaf_compare`` is given, dimensions that differ between the two samples
    are marked as dynamic (-1).
    """
    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
        # Plain Python object: store as a shapeless non-tensor spec.
        stack[name] = NonTensor(shape=())
        return
    elif is_non_tensor(leaf):
        # Non-tensor data wrapped in a tensor collection: keep its shape.
        stack[name] = NonTensor(shape=leaf.shape)
        return
    shape = leaf.shape
    if leaf_compare is not None:
        shape_compare = leaf_compare.shape
        # Any dimension that disagrees between the two samples is dynamic (-1).
        shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(shape, shape_compare)]
    stack[name] = Unbounded(shape, device=leaf.device, dtype=leaf.dtype)