torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/custom/llm.py
@@ -0,0 +1,214 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from collections.abc import Callable
+
+ import torch
+
+ from tensordict import NestedKey, set_list_to_stack, TensorDict, TensorDictBase
+ from tensordict.tensorclass import NonTensorData, NonTensorStack
+
+ from torchrl.data.map.hash import SipHash
+ from torchrl.data.tensor_specs import (
+     Categorical as CategoricalSpec,
+     Composite,
+     NonTensor,
+     Unbounded,
+ )
+ from torchrl.envs import EnvBase
+ from torchrl.envs.utils import _StepMDP
+
+
+ class LLMHashingEnv(EnvBase):
+     """A text generation environment that uses a hashing module to identify unique observations.
+
+     The primary goal of this environment is to identify token chains using a hashing function.
+     This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
+     identifiers, or easily prune repeated token chains in a data structure.
+
+     .. The following figure gives an overview of this workflow:
+     .. .. figure:: /_static/img/rollout-llm.png
+     ..   :alt: Data collection loop with our LLM environment.
+
+     Args:
+         vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
+
+     Keyword Args:
+         hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
+             A hashing function that takes a tensor as input and returns a hashed tensor.
+             Defaults to :class:`~torchrl.data.SipHash` if not provided.
+         observation_key (NestedKey, optional): The key for the observation in the TensorDict.
+             Defaults to "observation".
+         text_output (bool, optional): Whether to include the text output in the observation.
+             Defaults to `True`.
+         tokenizer (transformers.Tokenizer | None, optional):
+             A tokenizer function that converts text to tensors.
+             Only used when `text_output` is `True`.
+             Must implement the following methods: `decode` and `batch_decode`.
+             Defaults to ``None``.
+         text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
+             Defaults to "text".
+
+     Examples:
+         >>> from tensordict import TensorDict
+         >>> from torchrl.envs import LLMHashingEnv
+         >>> from transformers import GPT2Tokenizer
+         >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+         >>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
+         >>> env = LLMHashingEnv(tokenizer=tokenizer)
+         >>> td = TensorDict(observation=x, batch_size=[1])
+         >>> td = env.reset(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                 observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 text: NonTensorStack(
+                     ['Check out TorchRL!'],
+                     batch_size=torch.Size([1]),
+                     device=None)},
+             batch_size=torch.Size([1]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     def __init__(
+         self,
+         vocab_size: int | None = None,
+         *,
+         hashing_module: Callable[[torch.Tensor], torch.Tensor] = None,
+         observation_key: NestedKey = "observation",
+         text_output: bool = True,
+         tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None,
+         text_key: NestedKey | None = "text",
+     ):
+         super().__init__()
+         if vocab_size is None:
+             if tokenizer is None:
+                 raise TypeError(
+                     "You must provide a vocab_size integer if tokenizer is `None`."
+                 )
+             vocab_size = tokenizer.vocab_size
+         self._batch_locked = False
+         if hashing_module is None:
+             hashing_module = SipHash()
+
+         self._hashing_module = hashing_module
+         self._tokenizer = tokenizer
+         self.observation_key = observation_key
+         observation_spec = {
+             observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
+             "hashing": Unbounded(shape=(1,), dtype=torch.int64),
+         }
+         self.text_output = text_output
+         if not text_output:
+             text_key = None
+         elif text_key is None:
+             text_key = "text"
+         if text_key is not None:
+             observation_spec[text_key] = NonTensor(shape=())
+         self.text_key = text_key
+         self.observation_spec = Composite(observation_spec)
+         self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
+         _StepMDP(self)
+
+     @set_list_to_stack(True)
+     def make_tensordict(self, input: str | list[str]) -> TensorDict:
+         """Converts a string or list of strings in a TensorDict with appropriate shape and device."""
+         list_len = len(input) if isinstance(input, list) else 0
+         tensordict = TensorDict(
+             {self.observation_key: self._tokenizer(input)}, device=self.device
+         )
+         if list_len:
+             tensordict.batch_size = [list_len]
+         return self.reset(tensordict)
+
+     def _reset(self, tensordict: TensorDictBase):
+         """Initializes the environment with a given observation.
+
+         Args:
+             tensordict (TensorDictBase): A TensorDict containing the initial observation.
+
+         Returns:
+             A TensorDict containing the initial observation, its hash, and other relevant information.
+
+         """
+         out = tensordict.empty()
+         obs = tensordict.get(self.observation_key, None)
+         if obs is None:
+             raise RuntimeError(
+                 f"Resetting the {type(self).__name__} environment requires a prompt."
+             )
+         if self.text_output:
+             if obs.ndim > 1:
+                 text = self._tokenizer.batch_decode(obs)
+                 text = NonTensorStack.from_list(text)
+             else:
+                 text = self._tokenizer.decode(obs)
+                 text = NonTensorData(text)
+             out.set(self.text_key, text)
+
+         if obs.ndim > 1:
+             out.set("hashing", self._hashing_module(obs).unsqueeze(-1))
+         else:
+             out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))
+
+         if not self.full_done_spec.is_empty():
+             out.update(self.full_done_spec.zero(tensordict.shape))
+         else:
+             out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
+             out.set(
+                 "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
+             )
+         return out
+
+     def _step(self, tensordict):
+         """Takes an action (i.e., the next token to generate) and returns the next observation and reward.
+
+         Args:
+             tensordict: A TensorDict containing the current observation and action.
+
+         Returns:
+             A TensorDict containing the next observation, its hash, and other relevant information.
+         """
+         out = tensordict.empty()
+         action = tensordict.get("action")
+         obs = torch.cat([tensordict.get(self.observation_key), action], -1)
+         kwargs = {self.observation_key: obs}
+
+         catval = torch.cat([tensordict.get("hashing"), action], -1)
+         if obs.ndim > 1:
+             new_hash = self._hashing_module(catval).unsqueeze(-1)
+         else:
+             new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1)
+
+         if self.text_output:
+             if obs.ndim > 1:
+                 text = self._tokenizer.batch_decode(obs)
+                 text = NonTensorStack.from_list(text)
+             else:
+                 text = self._tokenizer.decode(obs)
+                 text = NonTensorData(text)
+             kwargs[self.text_key] = text
+         kwargs.update(
+             {
+                 "hashing": new_hash,
+                 "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
+                 "terminated": torch.zeros(
+                     (*tensordict.batch_size, 1), dtype=torch.bool
+                 ),
+             }
+         )
+         return out.update(kwargs)
+
+     def _set_seed(self, *args) -> None:
+         """Sets the seed for the environment's randomness.
+
+         .. note:: This environment has no randomness, so this method does nothing.
+         """
torchrl/envs/custom/pendulum.py
@@ -0,0 +1,401 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import numpy as np
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+ from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.utils import make_composite_from_td
+
+
+ class PendulumEnv(EnvBase):
+     """A stateless Pendulum environment.
+
+     See the Pendulum tutorial for more details: :ref:`tutorial <pendulum_tuto>`.
+
+     Specs:
+         >>> env = PendulumEnv()
+         >>> env.specs
+         Composite(
+             output_spec: Composite(
+                 full_observation_spec: Composite(
+                     th: BoundedContinuous(
+                         shape=torch.Size([]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     thdot: BoundedContinuous(
+                         shape=torch.Size([]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     params: Composite(
+                         max_speed: UnboundedDiscrete(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.int64,
+                             domain=discrete),
+                         max_torque: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         dt: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         g: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         m: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         l: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         device=None,
+                         shape=torch.Size([])),
+                     device=None,
+                     shape=torch.Size([])),
+                 full_reward_spec: Composite(
+                     reward: UnboundedContinuous(
+                         shape=torch.Size([1]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     device=None,
+                     shape=torch.Size([])),
+                 full_done_spec: Composite(
+                     done: Categorical(
+                         shape=torch.Size([1]),
+                         space=CategoricalBox(n=2),
+                         device=cpu,
+                         dtype=torch.bool,
+                         domain=discrete),
+                     terminated: Categorical(
+                         shape=torch.Size([1]),
+                         space=CategoricalBox(n=2),
+                         device=cpu,
+                         dtype=torch.bool,
+                         domain=discrete),
+                     device=None,
+                     shape=torch.Size([])),
+                 device=None,
+                 shape=torch.Size([])),
+             input_spec: Composite(
+                 full_state_spec: Composite(
+                     th: BoundedContinuous(
+                         shape=torch.Size([]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     thdot: BoundedContinuous(
+                         shape=torch.Size([]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     params: Composite(
+                         max_speed: UnboundedDiscrete(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.int64,
+                             domain=discrete),
+                         max_torque: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         dt: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         g: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         m: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         l: UnboundedContinuous(
+                             shape=torch.Size([]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             device=cpu,
+                             dtype=torch.float32,
+                             domain=continuous),
+                         device=None,
+                         shape=torch.Size([])),
+                     device=None,
+                     shape=torch.Size([])),
+                 full_action_spec: Composite(
+                     action: BoundedContinuous(
+                         shape=torch.Size([1]),
+                         space=ContinuousBox(
+                             low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+                             high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+                         device=cpu,
+                         dtype=torch.float32,
+                         domain=continuous),
+                     device=None,
+                     shape=torch.Size([])),
+                 device=None,
+                 shape=torch.Size([])),
+             device=None,
+             shape=torch.Size([]))
+
+     """
+
+     DEFAULT_X = np.pi
+     DEFAULT_Y = 1.0
+
+     metadata = {
+         "render_modes": ["human", "rgb_array"],
+         "render_fps": 30,
+     }
+     batch_locked = False
+     rng = None
+
+     def __init__(self, td_params=None, seed=None, device=None):
+         if td_params is None:
+             td_params = self.gen_params(device=self.device)
+
+         super().__init__(device=device)
+         self._make_spec(td_params)
+         if seed is None:
+             seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item()
+         self.set_seed(seed)
+
+     @classmethod
+     def _step(cls, tensordict):
+         th, thdot = tensordict["th"], tensordict["thdot"]  # th := theta
+
+         g_force = tensordict["params", "g"]
+         mass = tensordict["params", "m"]
+         length = tensordict["params", "l"]
+         dt = tensordict["params", "dt"]
+         u = tensordict["action"].squeeze(-1)
+         u = u.clamp(
+             -tensordict["params", "max_torque"], tensordict["params", "max_torque"]
+         )
+         costs = cls.angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)
+
+         new_thdot = (
+             thdot
+             + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u)
+             * dt
+         )
+         new_thdot = new_thdot.clamp(
+             -tensordict["params", "max_speed"], tensordict["params", "max_speed"]
+         )
+         new_th = th + new_thdot * dt
+         reward = -costs.view(*tensordict.shape, 1)
+         done = torch.zeros_like(reward, dtype=torch.bool)
+         out = TensorDict(
+             {
+                 "th": new_th,
+                 "thdot": new_thdot,
+                 "params": tensordict["params"],
+                 "reward": reward,
+                 "done": done,
+             },
+             tensordict.shape,
+         )
+         return out
+
+     def _reset(self, tensordict):
+         batch_size = (
+             tensordict.batch_size if tensordict is not None else self.batch_size
+         )
+         if tensordict is None or "params" not in tensordict:
+             # if no ``tensordict`` is passed, we generate a single set of hyperparameters
+             # Otherwise, we assume that the input ``tensordict`` contains all the relevant
+             # parameters to get started.
+             tensordict = self.gen_params(batch_size=batch_size, device=self.device)
+         elif "th" in tensordict and "thdot" in tensordict:
+             # we can hard-reset the env too
+             return tensordict
+         out = self._reset_random_data(
+             tensordict.shape, batch_size, tensordict["params"]
+         )
+         return out
+
+     def _reset_random_data(self, shape, batch_size, params):
+
+         high_th = torch.tensor(self.DEFAULT_X, device=self.device)
+         high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
+         low_th = -high_th
+         low_thdot = -high_thdot
+
+         # for non batch-locked environments, the input ``tensordict`` shape dictates the number
+         # of simulators run simultaneously. In other contexts, the initial
+         # random state's shape will depend upon the environment batch-size instead.
+         th = (
+             torch.rand(shape, generator=self.rng, device=self.device)
+             * (high_th - low_th)
+             + low_th
+         )
+         thdot = (
+             torch.rand(shape, generator=self.rng, device=self.device)
+             * (high_thdot - low_thdot)
+             + low_thdot
+         )
+         out = TensorDict(
+             {
+                 "th": th,
+                 "thdot": thdot,
+                 "params": params,
+             },
+             batch_size=batch_size,
+         )
+         return out
+
+     def _make_spec(self, td_params):
+         # Under the hood, this will populate self.output_spec["observation"]
+         self.observation_spec = Composite(
+             th=Bounded(
+                 low=-torch.pi,
+                 high=torch.pi,
+                 shape=(),
+                 dtype=torch.float32,
+             ),
+             thdot=Bounded(
+                 low=-td_params["params", "max_speed"],
+                 high=td_params["params", "max_speed"],
+                 shape=(),
+                 dtype=torch.float32,
+             ),
+             # we need to add the ``params`` to the observation specs, as we want
+             # to pass it at each step during a rollout
+             params=make_composite_from_td(
+                 td_params["params"], unsqueeze_null_shapes=False
+             ),
+             shape=(),
+         )
+         # since the environment is stateless, we expect the previous output as input.
+         # For this, ``EnvBase`` expects some state_spec to be available
+         self.state_spec = self.observation_spec.clone()
+         # action-spec will be automatically wrapped in input_spec when
+         # `self.action_spec = spec` will be called supported
+         self.action_spec = Bounded(
+             low=-td_params["params", "max_torque"],
+             high=td_params["params", "max_torque"],
+             shape=(1,),
+             dtype=torch.float32,
+         )
+         self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
+
+     def make_composite_from_td(td):
+         # custom function to convert a ``tensordict`` in a similar spec structure
+         # of unbounded values.
+         composite = Composite(
+             {
+                 key: make_composite_from_td(tensor)
+                 if isinstance(tensor, TensorDictBase)
+                 else Unbounded(
+                     dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
+                 )
+                 for key, tensor in td.items()
+             },
+             shape=td.shape,
+         )
+         return composite
+
+     def _set_seed(self, seed: int) -> None:
+         rng = torch.Generator(device=self.device)
+         rng.manual_seed(seed)
+         self.rng = rng
+
+     @staticmethod
+     def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase:
+         """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
+         if batch_size is None:
+             batch_size = []
+         td = TensorDict(
+             {
+                 "params": TensorDict(
+                     {
+                         "max_speed": 8,
+                         "max_torque": 2.0,
+                         "dt": 0.05,
+                         "g": g,
+                         "m": 1.0,
+                         "l": 1.0,
+                     },
+                     [],
+                 )
+             },
+             [],
+             device=device,
+         )
+         if batch_size:
+             td = td.expand(batch_size).contiguous()
+         return td
+
+     @staticmethod
+     def angle_normalize(x):
+         return ((x + torch.pi) % (2 * torch.pi)) - torch.pi
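
Not part of the package contents: a minimal rollout sketch for the stateless PendulumEnv added above. The import path mirrors the file location listed in this wheel; the rollout call relies on EnvBase sampling random torques when no policy is given, which is an assumption about the base class rather than something defined in this file.

# Hypothetical driver script.
from torchrl.envs.custom.pendulum import PendulumEnv

env = PendulumEnv(seed=0)

# With no policy, rollout() draws random torques from the bounded action spec.
rollout = env.rollout(100)
print(rollout["next", "reward"].mean())

# Because the env is stateless and not batch-locked, a batch of pendulums can be
# simulated by resetting with batched parameters from gen_params().
td = env.reset(env.gen_params(batch_size=[10]))
batched = env.rollout(100, auto_reset=False, tensordict=td)
print(batched["next", "reward"].shape)  # expected: torch.Size([10, 100, 1])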