torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/planners/mppi.py
@@ -0,0 +1,265 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from torch import nn
+
+from torchrl.modules.planners.common import MPCPlannerBase
+
+if TYPE_CHECKING:
+    from torchrl.envs.common import EnvBase
+
+
+class MPPIPlanner(MPCPlannerBase):
+    """MPPI Planner Module.
+
+    Reference:
+
+    - Model predictive path integral control using covariance variable importance
+      sampling. (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149
+    - Temporal Difference Learning for Model Predictive Control
+      (Hansen N., Wang X., Su H.) https://arxiv.org/abs/2203.04955
+
+    This module will perform an MPPI planning step when given a TensorDict
+    containing initial states.
+
+    A call to the module returns the actions that empirically maximised the
+    returns over the planning horizon.
+
+    Args:
+        env (EnvBase): The environment to perform the planning step on (can be
+            `ModelBasedEnv` or :obj:`EnvBase`).
+        planning_horizon (int): The length of the simulated trajectories.
+        optim_steps (int): The number of optimization steps used by the MPC
+            planner.
+        num_candidates (int): The number of candidates to sample from the
+            Gaussian distributions.
+        top_k (int): The number of top candidates to use to
+            update the mean and standard deviation of the Gaussian distribution.
+        reward_key (str, optional): The key in the TensorDict to use to
+            retrieve the reward. Defaults to ("next", "reward").
+        action_key (str, optional): The key in the TensorDict to use to store
+            the action. Defaults to "action".
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import Composite, Unbounded
+        >>> from torchrl.envs.model_based import ModelBasedEnvBase
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torchrl.modules import ValueOperator
+        >>> from torchrl.objectives.value import TDLambdaEstimator
+        >>> class MyMBEnv(ModelBasedEnvBase):
+        ...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
+        ...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
+        ...         self.state_spec = Composite(
+        ...             hidden_observation=Unbounded((4,))
+        ...         )
+        ...         self.observation_spec = Composite(
+        ...             hidden_observation=Unbounded((4,))
+        ...         )
+        ...         self.action_spec = Unbounded((1,))
+        ...         self.reward_spec = Unbounded((1,))
+        ...
+        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
+        ...         tensordict = TensorDict(
+        ...             {},
+        ...             batch_size=self.batch_size,
+        ...             device=self.device,
+        ...         )
+        ...         tensordict = tensordict.update(
+        ...             self.full_state_spec.rand())
+        ...         tensordict = tensordict.update(
+        ...             self.full_action_spec.rand())
+        ...         tensordict = tensordict.update(
+        ...             self.full_observation_spec.rand())
+        ...         return tensordict
+        ...
+        >>> from torchrl.modules import MLP, WorldModelWrapper
+        >>> import torch.nn as nn
+        >>> world_model = WorldModelWrapper(
+        ...     TensorDictModule(
+        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
+        ...         in_keys=["hidden_observation", "action"],
+        ...         out_keys=["hidden_observation"],
+        ...     ),
+        ...     TensorDictModule(
+        ...         nn.Linear(4, 1),
+        ...         in_keys=["hidden_observation"],
+        ...         out_keys=["reward"],
+        ...     ),
+        ... )
+        >>> env = MyMBEnv(world_model)
+        >>> value_net = nn.Linear(4, 1)
+        >>> value_net = ValueOperator(value_net, in_keys=["hidden_observation"])
+        >>> adv = TDLambdaEstimator(
+        ...     gamma=0.99,
+        ...     lmbda=0.95,
+        ...     value_network=value_net,
+        ... )
+        >>> # Build a planner and use it as actor
+        >>> planner = MPPIPlanner(
+        ...     env,
+        ...     adv,
+        ...     temperature=1.0,
+        ...     planning_horizon=10,
+        ...     optim_steps=11,
+        ...     num_candidates=7,
+        ...     top_k=3)
+        >>> env.rollout(5, planner)
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+    """
+
+    def __init__(
+        self,
+        env: EnvBase,
+        advantage_module: nn.Module,
+        temperature: float,
+        planning_horizon: int,
+        optim_steps: int,
+        num_candidates: int,
+        top_k: int,
+        reward_key: str = ("next", "reward"),
+        action_key: str = "action",
+    ):
+        super().__init__(env=env, action_key=action_key)
+        self.advantage_module = advantage_module
+        self.planning_horizon = planning_horizon
+        self.optim_steps = optim_steps
+        self.num_candidates = num_candidates
+        self.top_k = top_k
+        self.reward_key = reward_key
+        self.register_buffer("temperature", torch.as_tensor(temperature))
+
+    def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
+        batch_size = tensordict.batch_size
+        action_shape = (
+            *batch_size,
+            self.num_candidates,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        action_stats_shape = (
+            *batch_size,
+            1,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        action_topk_shape = (
+            *batch_size,
+            self.top_k,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        adv_topk_shape = (
+            *batch_size,
+            self.top_k,
+            1,
+            1,
+        )
+        K_DIM = len(self.action_spec.shape) - 4
+        expanded_original_tensordict = (
+            tensordict.unsqueeze(-1)
+            .expand(*batch_size, self.num_candidates)
+            .to_tensordict()
+        )
+        _action_means = torch.zeros(
+            *action_stats_shape,
+            device=tensordict.device,
+            dtype=self.env.action_spec.dtype,
+        )
+        _action_stds = torch.ones_like(_action_means)
+        container = TensorDict(
+            {
+                "tensordict": expanded_original_tensordict,
+                "stats": TensorDict(
+                    {
+                        "_action_means": _action_means,
+                        "_action_stds": _action_stds,
+                    },
+                    [*batch_size, 1, self.planning_horizon],
+                ),
+            },
+            batch_size,
+        )
+
+        for _ in range(self.optim_steps):
+            actions_means = container.get(("stats", "_action_means"))
+            actions_stds = container.get(("stats", "_action_stds"))
+            actions = actions_means + actions_stds * torch.randn(
+                *action_shape,
+                device=actions_means.device,
+                dtype=actions_means.dtype,
+            )
+            actions = self.env.action_spec.project(actions)
+            optim_tensordict = container.get("tensordict").clone()
+            policy = _PrecomputedActionsSequentialSetter(actions)
+            optim_tensordict = self.env.rollout(
+                max_steps=self.planning_horizon,
+                policy=policy,
+                auto_reset=False,
+                tensordict=optim_tensordict,
+            )
+            # compute advantage
+            self.advantage_module(optim_tensordict)
+            # get advantage of the current state
+            advantage = optim_tensordict["advantage"][..., :1, :]
+            # get top-k trajectories
+            _, top_k = advantage.topk(self.top_k, dim=K_DIM)
+            # get omega weights for each top-k trajectory
+            vals = advantage.gather(K_DIM, top_k.expand(adv_topk_shape))
+            Omegas = (self.temperature * vals).exp()
+
+            # gather best actions
+            best_actions = actions.gather(K_DIM, top_k.expand(action_topk_shape))
+
+            # compute weighted average
+            _action_means = (Omegas * best_actions).sum(
+                dim=K_DIM, keepdim=True
+            ) / Omegas.sum(K_DIM, True)
+            _action_stds = (
+                (Omegas * (best_actions - _action_means).pow(2)).sum(
+                    dim=K_DIM, keepdim=True
+                )
+                / Omegas.sum(K_DIM, True)
+            ).sqrt()
+            container.set_(("stats", "_action_means"), _action_means)
+            container.set_(("stats", "_action_stds"), _action_stds)
+        action_means = container.get(("stats", "_action_means"))
+        return action_means[..., 0, 0, :]
+
+
+class _PrecomputedActionsSequentialSetter:
+    def __init__(self, actions):
+        self.actions = actions
+        self.cmpt = 0
+
+    def __call__(self, tensordict):
+        # check that the step count does not exceed the horizon
+        if self.cmpt >= self.actions.shape[-2]:
+            raise ValueError("Precomputed actions sequence is too short")
+        tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
+        self.cmpt += 1
+        return tensordict
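
Note on the hunk above: the heart of the planning loop is MPPI's exponentially weighted refit. The advantages of the top-k sampled trajectories become weights Omega = exp(temperature * advantage), and the per-step action mean and standard deviation are re-estimated under those weights. The following is a minimal, self-contained sketch of one such update on dummy tensors; the shapes and variable names are illustrative only and are not part of the torchrl API:

import torch

# Illustrative sizes: K candidates, T planning steps, A action dims.
K, T, A, top_k, temperature = 7, 10, 1, 3, 1.0

actions = torch.randn(K, T, A)    # candidate action sequences
advantage = torch.randn(K, 1, 1)  # one score per candidate trajectory

# Select the top-k trajectories along the candidate dim (dim 0 here).
_, idx = advantage.topk(top_k, dim=0)
vals = advantage.gather(0, idx)                            # (top_k, 1, 1)
best_actions = actions.gather(0, idx.expand(top_k, T, A))  # (top_k, T, A)

# Exponential MPPI weights, then weighted mean/std over the top-k dim.
omega = (temperature * vals).exp()
mean = (omega * best_actions).sum(0, keepdim=True) / omega.sum(0, True)
std = (
    (omega * (best_actions - mean).pow(2)).sum(0, keepdim=True)
    / omega.sum(0, True)
).sqrt()
print(mean.shape, std.shape)  # torch.Size([1, 10, 1]) for both

In the planner these statistics are written back into the "stats" sub-TensorDict, so each of the optim_steps iterations samples its candidates from a progressively sharper Gaussian.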
torchrl/modules/tensordict_module/__init__.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchrl.modules.tensordict_module.actors import (
+    Actor,
+    ActorCriticOperator,
+    ActorCriticWrapper,
+    ActorValueOperator,
+    DecisionTransformerInferenceWrapper,
+    DistributionalQValueActor,
+    DistributionalQValueHook,
+    DistributionalQValueModule,
+    LMHeadActorValueOperator,
+    MultiStepActorWrapper,
+    ProbabilisticActor,
+    QValueActor,
+    QValueHook,
+    QValueModule,
+    TanhModule,
+    ValueOperator,
+)
+from torchrl.modules.tensordict_module.common import SafeModule, VmapModule
+from torchrl.modules.tensordict_module.exploration import (
+    AdditiveGaussianModule,
+    AdditiveGaussianWrapper,
+    EGreedyModule,
+    EGreedyWrapper,
+    OrnsteinUhlenbeckProcessModule,
+    OrnsteinUhlenbeckProcessWrapper,
+    RandomPolicy,
+)
+from torchrl.modules.tensordict_module.probabilistic import (
+    SafeProbabilisticModule,
+    SafeProbabilisticTensorDictSequential,
+)
+from torchrl.modules.tensordict_module.rnn import (
+    GRU,
+    GRUCell,
+    GRUModule,
+    LSTM,
+    LSTMCell,
+    LSTMModule,
+    recurrent_mode,
+    set_recurrent_mode,
+)
+from torchrl.modules.tensordict_module.sequence import SafeSequential
+from torchrl.modules.tensordict_module.world_models import WorldModelWrapper
+
+__all__ = [
+    "Actor",
+    "ActorCriticOperator",
+    "ActorCriticWrapper",
+    "ActorValueOperator",
+    "DecisionTransformerInferenceWrapper",
+    "DistributionalQValueActor",
+    "DistributionalQValueHook",
+    "DistributionalQValueModule",
+    "LMHeadActorValueOperator",
+    "MultiStepActorWrapper",
+    "ProbabilisticActor",
+    "QValueActor",
+    "QValueHook",
+    "QValueModule",
+    "TanhModule",
+    "ValueOperator",
+    "SafeModule",
+    "VmapModule",
+    "AdditiveGaussianModule",
+    "AdditiveGaussianWrapper",
+    "EGreedyModule",
+    "EGreedyWrapper",
+    "RandomPolicy",
+    "OrnsteinUhlenbeckProcessModule",
+    "OrnsteinUhlenbeckProcessWrapper",
+    "SafeProbabilisticModule",
+    "SafeProbabilisticTensorDictSequential",
+    "GRU",
+    "GRUCell",
+    "GRUModule",
+    "LSTM",
+    "LSTMCell",
+    "LSTMModule",
+    "recurrent_mode",
+    "set_recurrent_mode",
+    "SafeSequential",
+    "WorldModelWrapper",
+]
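
For reference, a minimal sketch of how a few of these re-exported symbols compose into an exploring discrete policy. This is an illustrative usage example under assumed-stable APIs (QValueActor, EGreedyModule and OneHot exist in torchrl, but exact signatures can shift between releases), not an excerpt from the package:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import OneHot
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyModule, QValueActor

# Greedy actor: argmax over the Q-values of a 4-observation, 2-action net.
spec = OneHot(2)
actor = QValueActor(nn.Linear(4, 2), in_keys=["observation"], spec=spec)

# Stack epsilon-greedy exploration on top of the greedy actor.
policy = TensorDictSequential(
    actor,
    EGreedyModule(spec=spec, eps_init=0.1, eps_end=0.1),
)

td = TensorDict({"observation": torch.randn(4)}, batch_size=[])
with set_exploration_type(ExplorationType.RANDOM):
    td = policy(td)
print(td["action"])  # one-hot action of shape (2,)

Importing from torchrl.modules rather than torchrl.modules.tensordict_module directly is the intended public path; the __all__ list above pins what this subpackage re-exports.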