torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,752 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import abc
9
+ import functools
10
+ import re
11
+ import warnings
12
+ from collections.abc import Callable, Mapping, Sequence
13
+ from typing import Any, TypeVar
14
+
15
+ import numpy as np
16
+ import torch
17
+ from tensordict import NonTensorData, TensorDict, TensorDictBase
18
+
19
+ from torchrl._utils import logger as torchrl_logger
20
+ from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
21
+ from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase
22
+
23
+ T = TypeVar("T", bound=EnvBase)
24
+
25
+
26
class BaseInfoDictReader(metaclass=abc.ABCMeta):
    """Abstract interface for readers that map an env's info dict into a tensordict.

    Concrete readers implement :meth:`__call__`, which writes entries of
    ``info_dict`` into ``tensordict``, and expose an :attr:`info_spec`
    describing the specs of the entries they produce.
    """

    @abc.abstractmethod
    def __call__(self, info_dict: dict[str, Any], tensordict: TensorDictBase) -> TensorDictBase:
        """Write the relevant entries of ``info_dict`` into ``tensordict`` and return it."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def info_spec(self) -> dict[str, TensorSpec]:
        """Specs of the entries this reader writes."""
        raise NotImplementedError
39
+
40
+
41
class default_info_dict_reader(BaseInfoDictReader):
    """Default info-key reader.

    Args:
        keys (list of keys, optional): If provided, the list of keys to get from
            the info dictionary. Defaults to all keys.
        spec (List[TensorSpec], Dict[str, TensorSpec] or Composite, optional):
            If a list of specs is provided, each spec will be matched to its
            corresponding key to form a :class:`torchrl.data.Composite`.
            If not provided, a composite spec with :class:`~torchrl.data.Unbounded`
            specs will lazily be created.
        ignore_private (bool, optional): If ``True``, private infos (starting with
            an underscore) will be ignored. Defaults to ``True``.

    In cases where keys can be directly written to a tensordict (mostly if they abide to the
    tensordict shape), one simply needs to indicate the keys to be registered during
    instantiation.

    Examples:
        >>> from torchrl.envs.libs.gym import GymWrapper
        >>> from torchrl.envs import default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0"))
        >>> env.set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """

    def __init__(
        self,
        keys: list[str] | None = None,
        spec: Sequence[TensorSpec] | dict[str, TensorSpec] | Composite | None = None,
        ignore_private: bool = True,
    ):
        self.ignore_private = ignore_private
        # Lazy mode: keys (and possibly the spec) are inferred from the first
        # info dict seen in __call__.
        self._lazy = False
        if keys is None:
            self._lazy = True
        self.keys = keys

        if spec is None and keys is None:
            # Fully lazy: the spec is built on the first call.
            _info_spec = None
        elif spec is None:
            # Keys known but no specs given: default to unbounded scalar specs.
            _info_spec = Composite({key: Unbounded(()) for key in keys}, shape=[])
        elif not isinstance(spec, Composite):
            if self.keys is not None and len(spec) != len(self.keys):
                raise ValueError(
                    "If specifying specs for info keys with a sequence, the "
                    "length of the sequence must match the number of keys"
                )
            if isinstance(spec, dict):
                _info_spec = Composite(spec, shape=[])
            else:
                # A sequence of specs: zip them with the keys in order.
                _info_spec = Composite(
                    {key: spec for key, spec in zip(keys, spec)}, shape=[]
                )
        else:
            # A ready-made Composite: clone it so the caller's object is not mutated.
            _info_spec = spec.clone()
        self._info_spec = _info_spec

    def __call__(
        self, info_dict: dict[str, Any], tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Write the registered (or inferred) info keys into ``tensordict``.

        Missing keys that have a registered spec are zero-filled; missing keys
        without a spec raise a :class:`KeyError`.
        """
        # FIX: ``self.keys`` may still be ``None`` in lazy mode, so test
        # truthiness instead of ``len(self.keys)`` (which raised TypeError).
        if not isinstance(info_dict, (dict, TensorDictBase)) and self.keys:
            warnings.warn(
                f"Found an info_dict of type {type(info_dict)} "
                f"but expected type or subtype `dict`."
            )
        keys = self.keys
        if keys is None:
            # Lazy key inference from the first info dict seen.
            keys = info_dict.keys()
            if self.ignore_private:
                keys = [key for key in keys if not key.startswith("_")]
            self.keys = keys
        # create an info_spec only if there is none
        info_spec = None if self.info_spec is not None else Composite()
        for key in keys:
            if key in info_dict:
                val = info_dict[key]
                # FIX: object-dtype arrays (ragged/nested) are stacked into a
                # regular array; plain Python scalars carry no ``.dtype``, so
                # guard the attribute access instead of crashing.
                if getattr(val, "dtype", None) == np.dtype("O"):
                    val = np.stack(val)
                tensordict.set(key, val)
                if info_spec is not None:
                    # Infer the spec from the value as stored in the tensordict.
                    val = tensordict.get(key)
                    info_spec[key] = Unbounded(
                        val.shape, device=val.device, dtype=val.dtype
                    )
            elif self.info_spec is not None:
                if key in self.info_spec:
                    # Fill missing with 0s
                    tensordict.set(key, self.info_spec[key].zero())
            else:
                raise KeyError(f"The key {key} could not be found or inferred.")
        # set the info spec if there wasn't any - this should occur only once in this class
        if info_spec is not None:
            if tensordict.device is not None:
                info_spec = info_spec.to(tensordict.device)
            self._info_spec = info_spec
        return tensordict

    def reset(self):
        """Forget lazily-inferred keys and spec so they are re-inferred on the next call."""
        self.keys = None
        self._info_spec = None

    @property
    def info_spec(self) -> dict[str, TensorSpec]:
        # ``None`` until keys/specs have been provided or inferred.
        return self._info_spec
151
+
152
+
153
+ class GymLikeEnv(_EnvWrapper):
154
+ """A gym-like env is an environment.
155
+
156
+ Its behavior is similar to gym environments in what common methods (specifically reset and step) are expected to do.
157
+
158
+ A :obj:`GymLikeEnv` has a :obj:`.step()` method with the following signature:
159
+
160
+ ``env.step(action: np.ndarray) -> Tuple[Union[np.ndarray, dict], double, bool, *info]``
161
+
162
+ where the outputs are the observation, reward and done state respectively.
163
+ In this implementation, the info output is discarded (but specific keys can be read
164
+ by updating info_dict_reader, see :meth:`set_info_dict_reader` method).
165
+
166
+ By default, the first output is written at the "observation" key-value pair in the output tensordict, unless
167
+ the first output is a dictionary. In that case, each observation output will be put at the corresponding
168
+ :obj:`f"{key}"` location for each :obj:`f"{key}"` of the dictionary.
169
+
170
+ It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
171
+ """
172
+
173
+ _info_dict_reader: list[BaseInfoDictReader]
174
+
175
    @classmethod
    def __new__(cls, *args, **kwargs):
        # NOTE: ``__new__`` is already an implicit static method; the extra
        # ``@classmethod`` wrapper changes how ``cls`` is bound and is kept
        # exactly as written — do not "simplify" without checking callers.
        # ``_batch_locked=True``: forwarded to the wrapper base constructor.
        self = super().__new__(cls, *args, _batch_locked=True, **kwargs)
        # Per-instance list of info-dict readers, filled later (see the class
        # docstring's reference to ``set_info_dict_reader``).
        self._info_dict_reader = []

        return self
181
+
182
+ def fast_encoding(self, mode: bool = True) -> T:
183
+ """Skips several checks during encoding of the environment output to accelerate the execution of the environment.
184
+
185
+ Args:
186
+ mode (bool, optional): the memoization mode. If ``True``, input checks will be executed only once and then
187
+ the encoding pipeline will be pre-recorded.
188
+
189
+ .. seealso:: :meth:`~torchrl.data.TensorSpec.memoize_cache`.
190
+
191
+ Example:
192
+ >>> from torchrl.envs import GymEnv
193
+ >>> from torch.utils.benchmark import Timer
194
+ >>>
195
+ >>> env = GymEnv("Pendulum-v1")
196
+ >>> t = Timer("env.rollout(1000, break_when_any_done=False)", globals=globals(), num_threads=32).adaptive_autorange()
197
+ >>> m = t.median
198
+ >>> print(f"Speed without memoizing: {1000/t.median: 4.4f}fps")
199
+ Speed without memoizing: 10141.5742fps
200
+ >>>
201
+ >>> env.fast_encoding()
202
+ >>> t = Timer("env.rollout(1000, break_when_any_done=False)", globals=globals(), num_threads=32).adaptive_autorange()
203
+ >>> m = t.median
204
+ >>> print(f"Speed with memoizing: {1000/t.median: 4.4f}fps")
205
+ Speed with memoizing: 10576.8388fps
206
+
207
+ """
208
+ self.specs.memoize_encode(mode=mode)
209
+ if mode:
210
+ if type(self).read_obs is not GymLikeEnv.read_obs:
211
+ raise RuntimeError(
212
+ "Cannot use fast_encoding as the read_obs method has been overwritten."
213
+ )
214
+ if type(self).read_reward is not GymLikeEnv.read_reward:
215
+ raise RuntimeError(
216
+ "Cannot use fast_encoding as the read_reward method has been overwritten."
217
+ )
218
+
219
+ if mode:
220
+ self.read_reward = self._read_reward_memo
221
+ self.read_obs = self._read_obs_memo
222
+ else:
223
+ self.read_reward = self._read_reward_eager
224
+ self.read_obs = self._read_obs_eager
225
+
226
+ def read_action(self, action):
227
+ """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment.
228
+
229
+ Args:
230
+ action (Tensor or TensorDict): an action to be taken in the environment
231
+
232
+ Returns: an action in a format compatible with the contained environment.
233
+
234
+ """
235
+ action_spec = self.full_action_spec
236
+ action_keys = self.action_keys
237
+ if len(action_keys) == 1:
238
+ action_spec = action_spec[action_keys[0]]
239
+ return action_spec.to_numpy(action, safe=False)
240
+
241
    def read_done(
        self,
        terminated: bool | None = None,
        truncated: bool | None = None,
        done: bool | None = None,
    ) -> tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]:
        """Done state reader.

        In torchrl, a `"done"` signal means that a trajectory has reach its end,
        either because it has been interrupted or because it is terminated.
        Truncated means the episode has been interrupted early.
        Terminated means the task is finished, the episode is completed.

        Args:
            terminated (np.ndarray, boolean or other format): completion state
                obtained from the environment.
                ``"terminated"`` equates to ``"termination"`` in gymnasium:
                the signal that the environment has reached the end of the
                episode, any data coming after this should be considered as nonsensical.
                Defaults to ``None``.
            truncated (bool or None): early truncation signal.
                Defaults to ``None``.
            done (bool or None): end-of-trajectory signal.
                This should be the fallback value of envs which do not specify
                if the ``"done"`` entry points to a ``"terminated"`` or
                ``"truncated"``.
                Defaults to ``None``.

        Returns: a tuple with 4 boolean / tensor values,

            - a terminated state,
            - a truncated state,
            - a done state,
            - a boolean value indicating whether the frame_skip loop should be broken.

        """
        # Derive `done` when the env did not provide it explicitly.
        # NOTE(review): when `truncated` is given, `terminated` is assumed
        # non-None here (`truncated | terminated`) — confirm against callers.
        if truncated is not None and done is None:
            done = truncated | terminated
        elif truncated is None and done is None:
            done = terminated
        # Break the frame_skip loop as soon as any element of the batch is done.
        do_break = done.any() if not isinstance(done, bool) else done
        if isinstance(done, bool):
            # Scalar signals are wrapped in 1-element lists so that
            # `torch.as_tensor` below produces 1-element tensors.
            done = [done]
            if terminated is not None:
                terminated = [terminated]
            if truncated is not None:
                truncated = [truncated]
        # NOTE(review): `torch.as_tensor(None)` would raise if terminated or
        # truncated is still None at this point — presumably callers always
        # supply them; verify.
        return (
            torch.as_tensor(terminated),
            torch.as_tensor(truncated),
            torch.as_tensor(done),
            do_break.any() if not isinstance(do_break, bool) else do_break,
        )
294
+
295
+ _read_reward: Callable[[Any], Any] | None = None
296
+
297
    def read_reward(self, reward):
        """Reads the reward and maps it to the reward space.

        Delegates to the eager implementation; ``fast_encoding`` may shadow
        this method on the instance with a memoized variant.

        Args:
            reward (torch.Tensor or TensorDict): reward to be mapped.

        """
        return self._read_reward_eager(reward)
305
+
306
+ def _read_reward_eager(self, reward):
307
+ if isinstance(reward, int) and reward == 0:
308
+ return self.reward_spec.zero()
309
+ reward = self.reward_spec.encode(reward, ignore_device=True)
310
+
311
+ if reward is None:
312
+ reward = torch.tensor(np.nan).expand(self.reward_spec.shape)
313
+
314
+ return reward
315
+
316
    def _read_reward_memo(self, reward):
        """Memoized variant of :meth:`read_reward`.

        On the first call, builds a pipeline of processing functions based on
        the nature of this first reward (zero int / encodable value / None)
        and caches it in ``self._read_reward``; subsequent calls reuse the
        cached pipeline directly. NOTE(review): this assumes later rewards
        have the same nature as the first one — confirm with callers.
        """
        func = self._read_reward
        if func is not None:
            # Fast path: pipeline already built on a previous call.
            return func(reward)
        funcs = []
        if isinstance(reward, int) and reward == 0:

            def process_zero(reward):
                # First reward was a plain 0: always return the spec's zero.
                return self.reward_spec.zero()

            funcs.append(process_zero)
        else:

            def encode_reward(reward):
                return self.reward_spec.encode(reward, ignore_device=True)

            funcs.append(encode_reward)

        if reward is None:

            def check_none(reward):
                # First reward was None: replace with NaN of the spec's shape.
                return torch.tensor(np.nan).expand(self.reward_spec.shape)

            funcs.append(check_none)

        if len(funcs) == 1:
            self._read_reward = funcs[0]
        else:
            # Chain the functions: calling the partial with ``reward``
            # evaluates ``reduce(lambda x, f: f(x), funcs, reward)``, i.e.
            # threads the reward through each function in order.
            self._read_reward = functools.partial(
                functools.reduce, lambda x, f: f(x), funcs
            )
        return self._read_reward(reward)
348
+
349
+ def read_obs(
350
+ self, observations: dict[str, Any] | torch.Tensor | np.ndarray
351
+ ) -> dict[str, Any]:
352
+ """Reads an observation from the environment and returns an observation compatible with the output TensorDict.
353
+
354
+ Args:
355
+ observations (observation under a format dictated by the inner env): observation to be read.
356
+
357
+ """
358
+ return self._read_obs_eager(observations)
359
+
360
+ def _read_obs_eager(
361
+ self, observations: dict[str, Any] | torch.Tensor | np.ndarray
362
+ ) -> dict[str, Any]:
363
+ if isinstance(observations, dict):
364
+ if "state" in observations and "observation" not in observations:
365
+ # we rename "state" in "observation" as "observation" is the conventional name
366
+ # for single observation in torchrl.
367
+ # naming it 'state' will result in envs that have a different name for the state vector
368
+ # when queried with and without pixels
369
+ observations["observation"] = observations.pop("state")
370
+ if not isinstance(observations, Mapping):
371
+ for key, spec in self.observation_spec.items(True, True):
372
+ observations_dict = {}
373
+ observations_dict[key] = spec.encode(observations, ignore_device=True)
374
+ # we don't check that there is only one spec because obs spec also
375
+ # contains the data spec of the info dict.
376
+ break
377
+ else:
378
+ raise RuntimeError("Could not find any element in observation_spec.")
379
+ observations = observations_dict
380
+ else:
381
+ for key, val in observations.items():
382
+ if isinstance(self.observation_spec[key], NonTensor):
383
+ observations[key] = NonTensorData(val)
384
+ else:
385
+ observations[key] = self.observation_spec[key].encode(
386
+ val, ignore_device=True
387
+ )
388
+ return observations
389
+
390
    # Memoized observation-reader installed lazily by ``_read_obs_memo``;
    # ``None`` until a first observation has been processed.
    _read_obs: Callable[[Any], Any] | None = None
391
+
392
+ def _read_obs_memo(
393
+ self, observations: dict[str, Any] | torch.Tensor | np.ndarray
394
+ ) -> dict[str, Any]:
395
+ func = self._read_obs
396
+ if func is not None:
397
+ return func(observations)
398
+ funcs = []
399
+ if isinstance(observations, (dict, Mapping)):
400
+ if "state" in observations and "observation" not in observations:
401
+
402
+ def process_dict_pop(observations):
403
+ observations["observation"] = observations.pop("state")
404
+ return observations
405
+
406
+ funcs.append(process_dict_pop)
407
+ for key in observations.keys():
408
+ if isinstance(self.observation_spec[key], NonTensor):
409
+
410
+ def process_dict(observations, key=key):
411
+ observations[key] = NonTensorData(observations[key])
412
+ return observations
413
+
414
+ else:
415
+
416
+ def process_dict(observations, key=key):
417
+ observations[key] = self.observation_spec[key].encode(
418
+ observations[key], ignore_device=True
419
+ )
420
+ return observations
421
+
422
+ funcs.append(process_dict)
423
+ else:
424
+ key = next(iter(self.observation_spec.keys(True, True)), None)
425
+ if key is None:
426
+ raise RuntimeError("Could not find any element in observation_spec.")
427
+ spec = self.observation_spec[key]
428
+
429
+ def process_non_dict(observations, spec=spec):
430
+ return {key: spec.encode(observations, ignore_device=True)}
431
+
432
+ funcs.append(process_non_dict)
433
+ if len(funcs) == 1:
434
+ self._read_obs = funcs[0]
435
+ else:
436
+ self._read_obs = functools.partial(
437
+ functools.reduce, lambda x, f: f(x), funcs
438
+ )
439
+ return self._read_obs(observations)
440
+
441
    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Run one (frame-skipped) step of the wrapped env and pack the result.

        Reads the action(s) from ``tensordict``, steps the inner env up to
        ``wrapper_frame_skip`` times (summing rewards, stopping early on
        done), then encodes observations, reward and done-flags into the
        output TensorDict.

        Args:
            tensordict (TensorDictBase): input carrying the action entries.

        Returns:
            TensorDictBase: the step output (observations, reward,
            done/terminated/truncated and any info-dict entries), moved to
            ``self.device`` if one is set.
        """
        if len(self.action_keys) == 1:
            # Use brackets to get non-tensor data
            action = tensordict[self.action_key]
        else:
            # Multiple action entries: pass them to the env as a plain dict.
            action = tensordict.select(*self.action_keys).to_dict()
        if self._convert_actions_to_numpy:
            action = self.read_action(action)

        reward = 0
        for _ in range(self.wrapper_frame_skip):
            step_result = self._env.step(action)
            (
                obs,
                _reward,
                terminated,
                truncated,
                done,
                info_dict,
            ) = self._output_transform(step_result)

            if _reward is not None:
                # Accumulate the reward over the skipped frames.
                reward = reward + _reward
            terminated, truncated, done, do_break = self.read_done(
                terminated=terminated, truncated=truncated, done=done
            )
            if do_break:
                # End of trajectory: stop the frame-skip loop early.
                break

        reward = self.read_reward(reward)
        obs_dict = self.read_obs(obs)
        obs_dict[self.reward_key] = reward

        # if truncated/terminated is not in the keys, we just don't pass it even if it
        # is defined.
        if terminated is None:
            terminated = done.clone()
        if truncated is not None:
            obs_dict["truncated"] = truncated
        obs_dict["done"] = done
        obs_dict["terminated"] = terminated
        # ``validated`` is a tri-state cache: None (unknown), True (values
        # need no recasting, unsafe construction allowed), False (checked
        # construction required).
        validated = self.validated
        if not validated:
            tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size)
            if validated is None:
                # check if any value has to be recast to something else. If not, we can safely
                # build the tensordict without running checks
                self.validated = all(
                    val is tensordict_out.get(key)
                    for key, val in TensorDict(obs_dict, []).items(True, True)
                )
        else:
            # Previously validated: skip per-entry checks for speed.
            tensordict_out = TensorDict._new_unsafe(
                obs_dict,
                batch_size=tensordict.batch_size,
            )
        if self.device is not None:
            tensordict_out = tensordict_out.to(self.device)

        if self.info_dict_reader and info_dict is not None:
            if not isinstance(info_dict, dict):
                warnings.warn(
                    f"Expected info to be a dictionary but got a {type(info_dict)} with values {str(info_dict)[:100]}."
                )
            else:
                for info_dict_reader in self.info_dict_reader:
                    out = info_dict_reader(info_dict, tensordict_out)
                    if out is not None:
                        # Readers may return a new tensordict instead of
                        # mutating in place.
                        tensordict_out = out
        return tensordict_out
511
+
512
+ @property
513
+ def validated(self):
514
+ return self.__dict__.get("_validated", None)
515
+
516
+ @validated.setter
517
+ def validated(self, value):
518
+ self.__dict__["_validated"] = value
519
+
520
+ def _reset(
521
+ self, tensordict: TensorDictBase | None = None, **kwargs
522
+ ) -> TensorDictBase:
523
+ if (
524
+ tensordict is not None
525
+ and "_reset" in tensordict
526
+ and not tensordict["_reset"].all()
527
+ ):
528
+ raise RuntimeError("Partial resets are not handled at this level.")
529
+ obs, info = self._reset_output_transform(self._env.reset(**kwargs))
530
+
531
+ source = self.read_obs(obs)
532
+
533
+ # _new_unsafe cannot be used because it won't wrap non-tensor correctly
534
+ tensordict_out = TensorDict(
535
+ source=source,
536
+ batch_size=self.batch_size,
537
+ )
538
+ if self.info_dict_reader and info is not None:
539
+ for info_dict_reader in self.info_dict_reader:
540
+ out = info_dict_reader(info, tensordict_out)
541
+ if out is not None:
542
+ tensordict_out = out
543
+ elif info is None and self.info_dict_reader:
544
+ # populate the reset with the items we have not seen from info
545
+ for key, item in self.observation_spec.items(True, True):
546
+ if key not in tensordict_out.keys(True, True):
547
+ tensordict_out[key] = item.zero()
548
+ if self.device is not None:
549
+ tensordict_out = tensordict_out.to(self.device)
550
+ return tensordict_out
551
+
552
    @abc.abstractmethod
    def _output_transform(
        self, step_outputs_tuple: tuple
    ) -> tuple[
        Any,
        float | np.ndarray,
        bool | np.ndarray | None,
        bool | np.ndarray | None,
        bool | np.ndarray | None,
        dict,
    ]:
        """A method to read the output of the env step.

        Must return a tuple: (obs, reward, terminated, truncated, done, info).
        If only one end-of-trajectory is passed, it is interpreted as ``"truncated"``.
        An attempt to retrieve ``"truncated"`` from the info dict is also undertaken.
        If 2 are passed (like in gymnasium), we interpret them as ``"terminated",
        "truncated"`` (``"truncated"`` meaning that the trajectory has been
        interrupted early), and ``"done"`` is the union of the two,
        ie. the unspecified end-of-trajectory signal.

        These three concepts have different usage:

        - ``"terminated"`` indicates the final stage of a Markov Decision
          Process. It means that one should not pay attention to the
          upcoming observations (eg., in value functions) as they should be
          regarded as not valid.
        - ``"truncated"`` means that the environment has reached a stage where
          we decided to stop the collection for some reason but the next
          observation should not be discarded. If it were not for this
          arbitrary decision, the collection could have proceeded further.
        - ``"done"`` is either one or the other. It is to be interpreted as
          "a reset should be called before the next step is undertaken".

        """
        ...
588
+
589
    @abc.abstractmethod
    def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple:
        """A method to read the output of the env reset.

        Must return a tuple ``(obs, info)`` as consumed by :meth:`_reset`;
        ``info`` may be ``None`` when the env does not provide one.
        """
        ...
592
+
593
    @_maybe_unlock
    def set_info_dict_reader(
        self,
        info_dict_reader: BaseInfoDictReader | None = None,
        ignore_private: bool = True,
    ) -> GymLikeEnv:
        """Sets an info_dict_reader function.

        This function should take as input an
        info_dict dictionary and the tensordict returned by the step function, and
        write values in an ad-hoc manner from one to the other.

        Args:
            info_dict_reader (Callable[[Dict], TensorDict], optional): a callable
                taking a input dictionary and output tensordict as arguments.
                This function should modify the tensordict in-place. If none is
                provided, :class:`~torchrl.envs.gym_like.default_info_dict_reader`
                will be used.
            ignore_private (bool, optional): If ``True``, private infos (starting with
                an underscore) will be ignored. Defaults to ``True``.

        Returns: the same environment with the dict_reader registered.

        .. note::
            Automatically registering an info_dict reader should be done via
            :meth:`auto_register_info_dict`, which will ensure that the env
            specs are properly constructed.

        Examples:
            >>> from torchrl.envs import default_info_dict_reader
            >>> from torchrl.envs.libs.gym import GymWrapper
            >>> reader = default_info_dict_reader(["my_info_key"])
            >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
            >>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
            >>> tensordict = env.reset()
            >>> tensordict = env.rand_step(tensordict)
            >>> assert "my_info_key" in tensordict.keys()

        """
        if info_dict_reader is None:
            # Fall back to the default reader.
            info_dict_reader = default_info_dict_reader(ignore_private=ignore_private)
        self.info_dict_reader.append(info_dict_reader)
        if isinstance(info_dict_reader, BaseInfoDictReader):
            # if we have a BaseInfoDictReader, we know what the specs will be
            # In other cases (eg, RoboHive) we will need to figure it out empirically.
            if (
                isinstance(info_dict_reader, default_info_dict_reader)
                and info_dict_reader.info_spec is None
            ):
                torchrl_logger.info(
                    "The info_dict_reader does not have specs. The only way to palliate to this issue automatically "
                    "is to run a dummy rollout and gather the specs automatically. "
                    "To silence this message, provide the specs directly to your spec reader."
                )
                # Gym does not guarantee that reset passes all info
                # so a reset + random step + reset is run to let the reader
                # discover the info entries empirically.
                self.reset()
                info_dict_reader.reset()
                self.rand_step()
                self.reset()

            self.observation_spec.update(info_dict_reader.info_spec)

        return self
656
+
657
    def auto_register_info_dict(
        self,
        ignore_private: bool = True,
        *,
        info_dict_reader: BaseInfoDictReader | None = None,
    ) -> EnvBase:
        """Automatically registers the info dict and appends :class:`~torch.envs.transforms.TensorDictPrimer` instances if needed.

        If no info_dict_reader is provided, it is assumed that all the information contained in the info dict can
        be registered as numerical values within the tensordict.

        This method returns a (possibly transformed) environment where we make sure that
        the :func:`torchrl.envs.utils.check_env_specs` succeeds, whether
        the info is filled at reset time.

        .. note:: This method requires running a few iterations in the environment to
            manually check that the behavior matches expectations.

        Args:
            ignore_private (bool, optional): If ``True``, private infos (starting with
                an underscore) will be ignored. Defaults to ``True``.

        Keyword Args:
            info_dict_reader (BaseInfoDictReader, optional): the info_dict_reader, if it is known in advance.
                Unlike :meth:`set_info_dict_reader`, this method will create the primers necessary to get
                :func:`~torchrl.envs.utils.check_env_specs` to run.

        Examples:
            >>> from torchrl.envs import GymEnv
            >>> env = GymEnv("HalfCheetah-v4")
            >>> # registers the info dict reader
            >>> env.auto_register_info_dict()
            GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu)
            >>> env.rollout(3)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 6]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                            reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                            x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
                        batch_size=torch.Size([3]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                    reward_ctrl: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    reward_run: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    x_position: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False),
                    x_velocity: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float64, is_shared=False)},
                batch_size=torch.Size([3]),
                device=cpu,
                is_shared=False)

        """
        from torchrl.envs import check_env_specs, TensorDictPrimer, TransformedEnv

        if self.info_dict_reader:
            raise RuntimeError("The environment already has an info-dict reader.")
        self.set_info_dict_reader(
            ignore_private=ignore_private, info_dict_reader=info_dict_reader
        )
        try:
            # If the specs already line up, no transform is needed.
            check_env_specs(self)
            return self
        except (AssertionError, RuntimeError) as err:
            # Known failure modes where reset() lacks info entries: patch the
            # env with a primer that fills the missing info keys at reset.
            patterns = [
                "The keys of the specs and data do not match",
                "The sets of keys in the tensordicts to stack are exclusive",
            ]
            for pattern in patterns:
                if re.search(pattern, str(err)):
                    result = TransformedEnv(
                        self, TensorDictPrimer(self.info_dict_reader[0].info_spec)
                    )
                    check_env_specs(result)
                    return result
            # Unrecognized failure: surface it to the caller.
            raise err
744
+
745
+ def __repr__(self) -> str:
746
+ return (
747
+ f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
748
+ )
749
+
750
    @property
    def info_dict_reader(self):
        """The registered info-dict readers (see :meth:`set_info_dict_reader`)."""
        return self._info_dict_reader