torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/models/recipes/impala.py
@@ -0,0 +1,185 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tensordict import TensorDictBase
+
+
+ # TODO: add the small architecture variant referenced in the IMPALA paper
+
+
+ class _ResNetBlock(nn.Module):
+     def __init__(
+         self,
+         num_ch,
+     ):
+         super().__init__()
+         resnet_block = []
+         resnet_block.append(nn.ReLU(inplace=True))
+         resnet_block.append(
+             nn.LazyConv2d(
+                 out_channels=num_ch,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+             )
+         )
+         resnet_block.append(nn.ReLU(inplace=True))
+         resnet_block.append(
+             nn.Conv2d(
+                 in_channels=num_ch,
+                 out_channels=num_ch,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+             )
+         )
+         self.seq = nn.Sequential(*resnet_block)
+
+     def forward(self, x):
+         # Residual connection (out-of-place add keeps autograd safe).
+         x = x + self.seq(x)
+         return x
+
+
+ class _ConvNetBlock(nn.Module):
+     def __init__(self, num_ch):
+         super().__init__()
+
+         conv = nn.LazyConv2d(
+             out_channels=num_ch,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+         )
+         mp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.feats_conv = nn.Sequential(conv, mp)
+         self.resnet1 = _ResNetBlock(num_ch=num_ch)
+         self.resnet2 = _ResNetBlock(num_ch=num_ch)
+
+     def forward(self, x):
+         x = self.feats_conv(x)
+         x = self.resnet1(x)
+         x = self.resnet2(x)
+         return x
+
+
+ class ImpalaNet(nn.Module):  # noqa: D101
+     def __init__(
+         self,
+         num_actions,
+         channels=(16, 32, 32),
+         out_features=256,
+         use_lstm=False,
+         batch_first=True,
+         clamp_reward=True,
+         one_hot=False,
+     ):
+         super().__init__()
+         self.batch_first = batch_first
+         self.use_lstm = use_lstm
+         self.clamp_reward = clamp_reward
+         self.one_hot = one_hot
+         self.num_actions = num_actions
+
+         layers = [_ConvNetBlock(num_ch) for num_ch in channels]
+         layers += [nn.ReLU(inplace=True)]
+         self.convs = nn.Sequential(*layers)
+         self.fc = nn.Sequential(nn.LazyLinear(out_features), nn.ReLU(inplace=True))
+
+         # FC output size + last reward.
+         core_output_size = out_features + 1
+
+         if use_lstm:
+             self.core = nn.LSTM(
+                 core_output_size,
+                 out_features,
+                 num_layers=1,
+                 batch_first=batch_first,
+             )
+             core_output_size = out_features
+
+         self.policy = nn.Linear(core_output_size, self.num_actions)
+         self.baseline = nn.Linear(core_output_size, 1)
+
+     def forward(self, x, reward, done, core_state=None, mask=None):  # noqa: D102
+         if self.batch_first:
+             B, T, *x_shape = x.shape
+             batch_shape = torch.Size([B, T])
+         else:
+             T, B, *x_shape = x.shape
+             batch_shape = torch.Size([T, B])
+         if mask is None:
+             x = x.view(-1, *x.shape[-3:])
+         else:
+             x = x[mask]
+             if x.ndimension() != 4:
+                 raise RuntimeError(
+                     f"masked input should have 4 dimensions but got {x.ndimension()} instead"
+                 )
+         x = self.convs(x)
+         x = x.view(B * T, -1)
+         x = self.fc(x)
+
+         if mask is None:
+             if self.batch_first:
+                 x = x.view(B, T, -1)
+             else:
+                 x = x.view(T, B, -1)
+         else:
+             x = self._allocate_masked_x(x, mask)
+
+         if self.clamp_reward:
+             reward = torch.clamp(reward, -1, 1)
+         reward = reward.unsqueeze(-1)
+
+         core_input = torch.cat([x, reward], dim=-1)
+
+         if self.use_lstm:
+             core_output, _ = self.core(core_input, core_state)
+         else:
+             core_output = core_input
+
+         policy_logits = self.policy(core_output)
+         baseline = self.baseline(core_output)
+
+         softmax_vals = F.softmax(policy_logits, dim=-1)
+         action = torch.multinomial(
+             softmax_vals.view(-1, softmax_vals.shape[-1]), num_samples=1
+         ).view(softmax_vals.shape[:-1])
+         if self.one_hot:
+             action = F.one_hot(action, policy_logits.shape[-1])
+
+         if policy_logits.shape[:2] != batch_shape:
+             raise RuntimeError("policy_logits and batch-shape mismatch")
+         if baseline.shape[:2] != batch_shape:
+             raise RuntimeError("baseline and batch-shape mismatch")
+         if action.shape[:2] != batch_shape:
+             raise RuntimeError("action and batch-shape mismatch")
+
+         return (action, policy_logits, baseline), core_state
+
+     def _allocate_masked_x(self, x, mask):
+         x_empty = torch.zeros(
+             *mask.shape[:2], x.shape[-1], device=x.device, dtype=x.dtype
+         )
+         x_empty[mask] = x
+         return x_empty
+
+
+ class ImpalaNetTensorDict(ImpalaNet):  # noqa: D101
+     observation_key = "pixels"
+
+     def forward(self, tensordict: TensorDictBase):  # noqa: D102
+         x = tensordict.get(self.observation_key)
+         done = tensordict.get("done").squeeze(-1)
+         reward = tensordict.get("reward").squeeze(-1)
+         mask = tensordict.get(("collector", "mask"))
+         core_state = (
+             tensordict.get("core_state") if "core_state" in tensordict.keys() else None
+         )
+         return super().forward(x, reward, done, core_state=core_state, mask=mask)
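
The recipe above can be smoke-tested in a few lines. The following sketch is not part of the wheel, and the batch/time/frame shapes are illustrative assumptions; ImpalaNet's lazy conv/linear layers infer their input sizes on the first call:

import torch
from torchrl.modules.models.recipes.impala import ImpalaNet

net = ImpalaNet(num_actions=6)    # lazy layers: in-channels inferred on first call
x = torch.randn(2, 5, 3, 84, 84)  # (batch, time, C, H, W) since batch_first=True
reward = torch.randn(2, 5)
done = torch.zeros(2, 5, dtype=torch.bool)  # accepted but unused by forward
(action, policy_logits, baseline), _ = net(x, reward, done)
print(action.shape, policy_logits.shape, baseline.shape)
# torch.Size([2, 5]) torch.Size([2, 5, 6]) torch.Size([2, 5, 1])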
torchrl/modules/models/utils.py
@@ -0,0 +1,162 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import inspect
+ import warnings
+ from collections.abc import Callable, Sequence
+
+ import torch
+ from torch import nn
+ from torchrl.data.utils import DEVICE_TYPING
+ from torchrl.modules.models.exploration import NoisyLazyLinear, NoisyLinear
+
+ LazyMapping = {
+     nn.Linear: nn.LazyLinear,
+     NoisyLinear: NoisyLazyLinear,
+ }
+
+
+ class SqueezeLayer(nn.Module):
+     """Squeezing layer.
+
+     Squeezes some given singleton dimensions of an input tensor.
+
+     Args:
+         dims (iterable): dimensions to be squeezed. Defaults to ``(-1,)``.
+
+     """
+
+     def __init__(self, dims: Sequence[int] = (-1,)):
+         super().__init__()
+         for dim in dims:
+             if dim >= 0:
+                 raise RuntimeError("dims must all be < 0")
+         self.dims = dims
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: D102
+         for dim in self.dims:
+             if input.shape[dim] != 1:
+                 raise RuntimeError(
+                     f"Tried to squeeze an input over dims {self.dims} with shape {input.shape}"
+                 )
+             input = input.squeeze(dim)
+         return input
+
+
+ class Squeeze2dLayer(SqueezeLayer):
+     """Squeezing layer for convolutional neural networks.
+
+     Squeezes the last two singleton dimensions of an input tensor.
+
+     """
+
+     def __init__(self):
+         super().__init__((-2, -1))
+
+
+ class SquashDims(nn.Module):
+     """A squashing layer.
+
+     Flattens the N last dimensions of an input tensor.
+
+     Args:
+         ndims_in (int): number of dimensions to be flattened. Defaults to ``3``.
+
+     Examples:
+         >>> from torchrl.modules.models.utils import SquashDims
+         >>> import torch
+         >>> x = torch.randn(1, 2, 3, 4)
+         >>> print(SquashDims()(x).shape)
+         torch.Size([1, 24])
+
+     """
+
+     def __init__(self, ndims_in: int = 3):
+         super().__init__()
+         self.ndims_in = ndims_in
+
+     def forward(self, value: torch.Tensor) -> torch.Tensor:
+         value = value.flatten(-self.ndims_in, -1)
+         return value
+
+
+ def _find_depth(depth: int | None, *list_or_ints: Sequence):
+     """Find depth based on a sequence of inputs and a depth indicator.
+
+     If the depth is None, it is inferred from the length of one (or more) matching
+     lists of integers.
+     Raises an exception if the depth does not match the list lengths, or if the
+     list lengths do not match one another.
+
+     Args:
+         depth (int, optional): depth of the network.
+         *list_or_ints (lists of int or int): if depth is None, at least one of
+             these inputs must be a list of ints of the length of the desired
+             network.
+     """
+     if depth is None:
+         for item in list_or_ints:
+             if isinstance(item, (list, tuple)):
+                 depth = len(item)
+     if depth is None:
+         raise ValueError(
+             f"depth=None requires one of the input args (kernel_sizes, strides, "
+             f"num_cells) to be a list or tuple. Got {tuple(type(item) for item in list_or_ints)}"
+         )
+     return depth
+
+
+ def create_on_device(
+     module_class: type[nn.Module] | Callable,
+     device: DEVICE_TYPING | None,
+     *args,
+     **kwargs,
+ ) -> nn.Module:
+     """Create a new instance of :obj:`module_class` on :obj:`device`.
+
+     The new instance is created directly on the device if its constructor supports this.
+
+     Args:
+         module_class (Type[nn.Module]): the class of module to be created.
+         device (DEVICE_TYPING): device to create the module on.
+         *args: positional arguments to be passed to the module constructor.
+         **kwargs: keyword arguments to be passed to the module constructor.
+
+     """
+     fullargspec = inspect.getfullargspec(module_class.__init__)
+     if "device" in fullargspec.args or "device" in fullargspec.kwonlyargs:
+         return module_class(*args, device=device, **kwargs)
+     else:
+         result = module_class(*args, **kwargs)
+         if hasattr(result, "to"):
+             result = result.to(device)
+         return result
+
+
+ def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool:
+     """Recursively resets the parameters of a :class:`~torch.nn.Module` in-place.
+
+     Args:
+         module (torch.nn.Module): the module to reset.
+         warn_if_no_op (bool, optional): whether to raise a warning in case this is a no-op.
+             Defaults to ``True``.
+
+     Returns: whether any parameter has been reset.
+
+     """
+     any_reset = False
+     for layer in module.children():
+         if hasattr(layer, "reset_parameters"):
+             layer.reset_parameters()
+             any_reset |= True
+         any_reset |= _reset_parameters_recursive(layer, warn_if_no_op=False)
+     if warn_if_no_op and not any_reset:
+         warnings.warn(
+             "_reset_parameters_recursive was a no-op: no layer with a "
+             "reset_parameters() method was found among the module's children."
+         )
+     return any_reset
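
A short usage sketch for create_on_device (not shipped in the wheel): nn.Linear exposes a device keyword, so it is constructed directly on the target device, while nn.Sequential does not and is moved after construction.

import torch
from torch import nn
from torchrl.modules.models.utils import create_on_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constructor accepts `device`: instantiated directly on the target device.
lin = create_on_device(nn.Linear, device, 4, 2)

# Constructor has no `device` argument: built on CPU, then moved via `.to`.
seq = create_on_device(nn.Sequential, device, nn.Linear(4, 2), nn.ReLU())

print(lin.weight.device, next(seq.parameters()).device)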
torchrl/modules/planners/__init__.py
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .cem import CEMPlanner
+ from .common import MPCPlannerBase
+ from .mppi import MPPIPlanner
+
+ __all__ = ["CEMPlanner", "MPCPlannerBase", "MPPIPlanner"]
torchrl/modules/planners/cem.py
@@ -0,0 +1,228 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+ from torchrl.modules.planners.common import MPCPlannerBase
+
+ if TYPE_CHECKING:
+     from torchrl.envs.common import EnvBase
+
+
+ class CEMPlanner(MPCPlannerBase):
+     """CEMPlanner Module.
+
+     Reference: The cross-entropy method for optimization, Botev et al. 2013
+
+     This module performs a CEM planning step when given a TensorDict
+     containing initial states.
+     The CEM planning step is performed by sampling actions from a Gaussian
+     distribution with zero mean and unit variance.
+     The sampled actions are then used to perform a rollout in the environment.
+     The cumulative rewards obtained with the rollouts are then ranked. We
+     select the top-k episodes and use their actions to update the mean and
+     standard deviation of the action distribution.
+     The CEM planning step is repeated for a specified number of steps.
+
+     A call to the module returns the actions that empirically maximised the
+     returns given a planning horizon.
+
+     Args:
+         env (EnvBase): The environment to perform the planning step on (can be
+             :obj:`ModelBasedEnvBase` or :obj:`EnvBase`).
+         planning_horizon (int): The length of the simulated trajectories.
+         optim_steps (int): The number of optimization steps used by the MPC
+             planner.
+         num_candidates (int): The number of candidates to sample from the
+             Gaussian distribution.
+         top_k (int): The number of top candidates to use to
+             update the mean and standard deviation of the Gaussian distribution.
+         reward_key (str, optional): The key in the TensorDict to use to
+             retrieve the reward. Defaults to ``("next", "reward")``.
+         action_key (str, optional): The key in the TensorDict to use to store
+             the action. Defaults to ``"action"``.
+
+     Examples:
+         >>> from tensordict import TensorDict
+         >>> from torchrl.data import Composite, Unbounded
+         >>> from torchrl.envs.model_based import ModelBasedEnvBase
+         >>> from torchrl.modules import SafeModule
+         >>> class MyMBEnv(ModelBasedEnvBase):
+         ...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
+         ...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
+         ...         self.state_spec = Composite(
+         ...             hidden_observation=Unbounded((4,))
+         ...         )
+         ...         self.observation_spec = Composite(
+         ...             hidden_observation=Unbounded((4,))
+         ...         )
+         ...         self.action_spec = Unbounded((1,))
+         ...         self.reward_spec = Unbounded((1,))
+         ...
+         ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
+         ...         tensordict = TensorDict(
+         ...             {},
+         ...             batch_size=self.batch_size,
+         ...             device=self.device,
+         ...         )
+         ...         tensordict = tensordict.update(
+         ...             self.full_state_spec.rand())
+         ...         tensordict = tensordict.update(
+         ...             self.full_action_spec.rand())
+         ...         tensordict = tensordict.update(
+         ...             self.full_observation_spec.rand())
+         ...         return tensordict
+         ...
+         >>> from torchrl.modules import MLP, WorldModelWrapper
+         >>> import torch.nn as nn
+         >>> world_model = WorldModelWrapper(
+         ...     SafeModule(
+         ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
+         ...         in_keys=["hidden_observation", "action"],
+         ...         out_keys=["hidden_observation"],
+         ...     ),
+         ...     SafeModule(
+         ...         nn.Linear(4, 1),
+         ...         in_keys=["hidden_observation"],
+         ...         out_keys=["reward"],
+         ...     ),
+         ... )
+         >>> env = MyMBEnv(world_model)
+         >>> # Build a planner and use it as actor
+         >>> planner = CEMPlanner(env, 10, 11, 7, 3)
+         >>> env.rollout(5, planner)
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                 done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([5]),
+                     device=cpu,
+                     is_shared=False),
+                 terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([5]),
+             device=cpu,
+             is_shared=False)
+     """
+
+     def __init__(
+         self,
+         env: EnvBase,
+         planning_horizon: int,
+         optim_steps: int,
+         num_candidates: int,
+         top_k: int,
+         reward_key: str | tuple[str, ...] = ("next", "reward"),
+         action_key: str = "action",
+     ):
+         super().__init__(env=env, action_key=action_key)
+         self.planning_horizon = planning_horizon
+         self.optim_steps = optim_steps
+         self.num_candidates = num_candidates
+         self.top_k = top_k
+         self.reward_key = reward_key
+
+     def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
+         batch_size = tensordict.batch_size
+         action_shape = (
+             *batch_size,
+             self.num_candidates,
+             self.planning_horizon,
+             *self.action_spec.shape,
+         )
+         action_stats_shape = (
+             *batch_size,
+             1,
+             self.planning_horizon,
+             *self.action_spec.shape,
+         )
+         action_topk_shape = (
+             *batch_size,
+             self.top_k,
+             self.planning_horizon,
+             *self.action_spec.shape,
+         )
+         TIME_DIM = len(self.action_spec.shape) - 3
+         K_DIM = len(self.action_spec.shape) - 4
+         expanded_original_tensordict = (
+             tensordict.unsqueeze(-1)
+             .expand(*batch_size, self.num_candidates)
+             .to_tensordict()
+         )
+         _action_means = torch.zeros(
+             *action_stats_shape,
+             device=tensordict.device,
+             dtype=self.env.action_spec.dtype,
+         )
+         _action_stds = torch.ones_like(_action_means)
+         container = TensorDict(
+             {
+                 "tensordict": expanded_original_tensordict,
+                 "stats": TensorDict(
+                     {
+                         "_action_means": _action_means,
+                         "_action_stds": _action_stds,
+                     },
+                     [*batch_size, 1, self.planning_horizon],
+                 ),
+             },
+             batch_size,
+         )
+
+         for _ in range(self.optim_steps):
+             actions_means = container.get(("stats", "_action_means"))
+             actions_stds = container.get(("stats", "_action_stds"))
+             actions = actions_means + actions_stds * torch.randn(
+                 *action_shape,
+                 device=actions_means.device,
+                 dtype=actions_means.dtype,
+             )
+             actions = self.env.action_spec.project(actions)
+             optim_tensordict = container.get("tensordict").clone()
+             policy = _PrecomputedActionsSequentialSetter(actions)
+             optim_tensordict = self.env.rollout(
+                 max_steps=self.planning_horizon,
+                 policy=policy,
+                 auto_reset=False,
+                 tensordict=optim_tensordict,
+             )
+
+             sum_rewards = optim_tensordict.get(self.reward_key).sum(
+                 dim=TIME_DIM, keepdim=True
+             )
+             _, top_k = sum_rewards.topk(self.top_k, dim=K_DIM)
+             top_k = top_k.expand(action_topk_shape)
+             best_actions = actions.gather(K_DIM, top_k)
+             container.set_(
+                 ("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True)
+             )
+             container.set_(
+                 ("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True)
+             )
+         action_means = container.get(("stats", "_action_means"))
+         return action_means[..., 0, 0, :]
+
+
+ class _PrecomputedActionsSequentialSetter:
+     def __init__(self, actions):
+         self.actions = actions
+         self.cmpt = 0
+
+     def __call__(self, tensordict):
+         # Check that the step count does not exceed the planning horizon.
+         if self.cmpt >= self.actions.shape[-2]:
+             raise ValueError("Precomputed actions sequence is too short")
+         tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
+         self.cmpt += 1
+         return tensordict
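
The optimization loop in CEMPlanner.planning is the generic cross-entropy method: sample candidates from a Gaussian, rank them by return, refit the Gaussian on the elite set, and repeat. A self-contained toy version of that loop (illustrative only, not torchrl code, and minimizing a scalar objective rather than maximizing rewards):

import torch

def cem_minimize(f, optim_steps=20, num_candidates=64, top_k=6):
    mean, std = torch.zeros(1), torch.ones(1)
    for _ in range(optim_steps):
        # Sample candidates from the current Gaussian.
        candidates = mean + std * torch.randn(num_candidates, 1)
        scores = f(candidates).squeeze(-1)
        # Keep the top-k candidates (smallest scores) and refit the Gaussian.
        elite = candidates[scores.topk(top_k, largest=False).indices]
        mean, std = elite.mean(0), elite.std(0)
    return mean

# The minimum of (x - 3)^2 is recovered at x ≈ 3.
print(cem_minimize(lambda x: (x - 3.0) ** 2))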
torchrl/modules/planners/common.py
@@ -0,0 +1,73 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import abc
+ from typing import TYPE_CHECKING
+
+ import torch
+ from tensordict import TensorDictBase
+
+ from torchrl.modules import SafeModule
+
+ if TYPE_CHECKING:
+     from torchrl.envs.common import EnvBase
+
+
+ class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta):
+     """MPCPlannerBase abstract Module.
+
+     This class inherits from :obj:`SafeModule`. Given a :obj:`TensorDict`, this module performs a Model Predictive Control (MPC) planning step.
+     At the end of the planning step, the :obj:`MPCPlanner` returns a proposed action.
+
+     Args:
+         env (EnvBase): The environment to perform the planning step on (can be :obj:`ModelBasedEnvBase` or :obj:`EnvBase`).
+         action_key (str, optional): The key that will point to the computed action. Defaults to ``"action"``.
+     """
+
+     def __init__(
+         self,
+         env: EnvBase,
+         action_key: str = "action",
+     ):
+         # Check that the env is stateless (not batch-locked).
+         if env.batch_locked:
+             raise ValueError(
+                 "Environment is batch_locked. MPCPlanners need an environment that accepts batched inputs with any batch size"
+             )
+         out_keys = [action_key]
+         in_keys = list(env.observation_spec.keys(True, True))
+         super().__init__(env, in_keys=in_keys, out_keys=out_keys)
+         self.env = env
+         self.action_spec = env.action_spec
+         self.to(env.device)
+
+     @abc.abstractmethod
+     def planning(self, td: TensorDictBase) -> torch.Tensor:
+         """Performs the MPC planning step.
+
+         Args:
+             td (TensorDict): The TensorDict to perform the planning step on.
+         """
+         raise NotImplementedError()
+
+     def forward(
+         self,
+         tensordict: TensorDictBase,
+         tensordict_out: TensorDictBase | None = None,
+         **kwargs,
+     ) -> TensorDictBase:
+         if "params" in kwargs or "vmap" in kwargs:
+             raise ValueError(
+                 "MPCPlannerBase does not currently support functional programming."
+             )
+         action = self.planning(tensordict)
+         action = self.action_spec.project(action)
+         tensordict_out = self._write_to_tensordict(
+             tensordict,
+             (action,),
+             tensordict_out,
+         )
+         return tensordict_out
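
Concrete planners only need to implement planning; key management, spec projection, and writing the action back to the tensordict are inherited from the base class. Below is a hedged sketch of that contract with a hypothetical RandomShootingPlanner (not part of torchrl), assuming an unbatched (empty batch-size) input tensordict and a stateless env prepared for rollouts:

import torch
from tensordict import TensorDictBase
from torchrl.modules.planners.common import MPCPlannerBase

class RandomShootingPlanner(MPCPlannerBase):
    """Pick the first action of the best of N random action sequences."""

    def __init__(self, env, planning_horizon: int, num_candidates: int):
        super().__init__(env=env, action_key="action")
        self.planning_horizon = planning_horizon
        self.num_candidates = num_candidates

    def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
        # Replicate the initial state once per candidate sequence.
        expanded = (
            tensordict.unsqueeze(-1).expand(self.num_candidates).to_tensordict()
        )
        # (num_candidates, planning_horizon, *action_shape) random candidates.
        actions = self.action_spec.rand((self.num_candidates, self.planning_horizon))
        step = {"i": 0}

        def policy(td):
            td.set("action", actions[:, step["i"]])
            step["i"] += 1
            return td

        rollout = self.env.rollout(
            max_steps=self.planning_horizon,
            policy=policy,
            auto_reset=False,
            tensordict=expanded,
        )
        # Rank sequences by cumulative reward; return the best first action.
        returns = rollout.get(("next", "reward")).sum(dim=(-2, -1))
        return actions[returns.argmax(), 0]

Unlike CEMPlanner, this planner samples once and never refits the sampling distribution, which is why CEM typically finds better action sequences for the same budget.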