torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,570 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from functools import partial
10
+ from typing import Any
11
+
12
+ import torch
13
+
14
+ from omegaconf import MISSING
15
+
16
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
17
+
18
+
19
@dataclass
class ActivationConfig(ConfigBase):
    """Hydra config describing an activation function.

    The target defaults to :class:`torch.nn.Tanh`. Point ``_target_`` at any
    other import path to pick a different activation.

    .. seealso:: :class:`torch.nn.Tanh`
    """

    _target_: str = "torch.nn.Tanh"
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Do nothing: activation configs need no extra post-processing."""
        return None
33
+
34
+
35
@dataclass
class LayerConfig(ConfigBase):
    """Hydra config describing a single layer class.

    The target defaults to :class:`torch.nn.Linear`. Point ``_target_`` at any
    other import path to use a different layer type.

    .. seealso:: :class:`torch.nn.Linear`
    """

    _target_: str = "torch.nn.Linear"
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Do nothing: layer configs need no extra post-processing."""
        return None
49
+
50
+
51
@dataclass
class NetworkConfig(ConfigBase):
    """Common base for the network configs (e.g. MLP, ConvNet)."""

    _partial_: bool = False

    def __post_init__(self) -> None:
        """Do nothing here; subclasses may override this hook."""
        return None
59
+
60
+
61
@dataclass
class MLPConfig(NetworkConfig):
    """Configuration for :class:`torchrl.modules.MLP`.

    ``activation_class`` and ``layer_class`` accept either a config object or a
    fully-qualified target string such as ``"torch.nn.ReLU"``. A string is
    wrapped into a partial :class:`ActivationConfig` / :class:`LayerConfig`
    once the dataclass has been initialized.

    Example:
        >>> cfg = MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32)
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.MLP`
    """

    in_features: int | None = None
    out_features: Any = None
    depth: int | None = None
    num_cells: Any = None
    # Declared with ``_partial_=True`` so the target is built as a factory,
    # not as a finished module instance.
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig, _target_="torch.nn.Tanh", _partial_=True
        )
    )
    activation_kwargs: Any = None
    norm_class: Any = None
    norm_kwargs: Any = None
    dropout: float | None = None
    bias_last_layer: bool = True
    single_bias_last_layer: bool = False
    layer_class: LayerConfig = field(
        default_factory=partial(LayerConfig, _target_="torch.nn.Linear", _partial_=True)
    )
    layer_kwargs: dict | None = None
    activate_last_layer: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.MLP"

    def __post_init__(self):
        """Turn string targets into partial sub-configs."""
        activation_target = self.activation_class
        if isinstance(activation_target, str):
            self.activation_class = ActivationConfig(
                _target_=activation_target, _partial_=True
            )
        layer_target = self.layer_class
        if isinstance(layer_target, str):
            self.layer_class = LayerConfig(_target_=layer_target, _partial_=True)
104
+
105
+
106
@dataclass
class NormConfig(ConfigBase):
    """Hydra config describing a normalization layer.

    The target defaults to :class:`torch.nn.BatchNorm1d`. Point ``_target_`` at
    another import path to use a different normalization layer.

    .. seealso:: :class:`torch.nn.BatchNorm1d`
    """

    _target_: str = "torch.nn.BatchNorm1d"
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Do nothing: normalization configs need no extra post-processing."""
        return None
120
+
121
+
122
@dataclass
class AggregatorConfig(ConfigBase):
    """Hydra config describing an aggregator layer.

    The target defaults to :class:`torchrl.modules.models.utils.SquashDims`.

    .. seealso:: :class:`torchrl.modules.models.utils.SquashDims`
    """

    _target_: str = "torchrl.modules.models.utils.SquashDims"
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Do nothing: aggregator configs need no extra post-processing."""
        return None
136
+
137
+
138
@dataclass
class ConvNetConfig(NetworkConfig):
    """A class to configure a convolutional network.

    Defaults to :class:`torchrl.modules.ConvNet`.

    ``activation_class``, ``norm_class`` and ``aggregator_class`` may also be
    given as fully-qualified target strings (e.g. ``"torch.nn.ReLU"``).
    ``__post_init__`` wraps each string into a partial
    :class:`ActivationConfig`, :class:`NormConfig` or
    :class:`AggregatorConfig`, respectively. ``None`` values are left as-is.

    Example:
        >>> cfg = ConvNetConfig(in_features=3, depth=2, num_cells=[32, 64], kernel_sizes=[3, 5], strides=[1, 2], paddings=[1, 2])
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 3, 32, 32))
        >>> assert y.shape == (1, 64)

    NOTE(review): the default aggregator is ``SquashDims``, which presumably
    flattens the trailing feature/spatial dims. If so, the output above would
    be ``(1, 64 * 16 * 16)`` rather than ``(1, 64)`` -- confirm against
    :class:`~torchrl.modules.ConvNet`.

    .. seealso:: :class:`torchrl.modules.ConvNet`
    """

    in_features: int | None = None
    depth: int | None = None
    num_cells: Any = None
    kernel_sizes: Any = 3
    strides: Any = 1
    paddings: Any = 0
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig, _target_="torch.nn.ELU", _partial_=True
        )
    )
    activation_kwargs: Any = None
    norm_class: NormConfig | None = None
    norm_kwargs: Any = None
    bias_last_layer: bool = True
    aggregator_class: AggregatorConfig = field(
        default_factory=partial(
            AggregatorConfig,
            _target_="torchrl.modules.models.utils.SquashDims",
            _partial_=True,
        )
    )
    aggregator_kwargs: dict | None = None
    squeeze_output: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.ConvNet"

    def __post_init__(self) -> None:
        """Wrap string class targets into partial sub-configs.

        Accepting plain target strings mirrors what :class:`MLPConfig` does for
        its activation and layer classes. Only ``str`` values are converted;
        config objects and ``None`` pass through unchanged.
        """
        super().__post_init__()
        # A value cannot be both ``None`` and a ``str``, so the checks must be
        # a plain ``isinstance`` test. Otherwise these branches are dead code.
        if isinstance(self.activation_class, str):
            self.activation_class = ActivationConfig(
                _target_=self.activation_class, _partial_=True
            )
        if isinstance(self.norm_class, str):
            self.norm_class = NormConfig(_target_=self.norm_class, _partial_=True)
        if isinstance(self.aggregator_class, str):
            self.aggregator_class = AggregatorConfig(
                _target_=self.aggregator_class, _partial_=True
            )
191
+
192
+
193
@dataclass
class ModelConfig(ConfigBase):
    """Common base for model configs.

    A model may combine several networks and is always a
    :class:`~tensordict.nn.TensorDictModuleBase` instance.

    .. seealso:: :class:`TanhNormalModelConfig`, :class:`ValueModelConfig`
    """

    _partial_: bool = False
    in_keys: Any = None
    out_keys: Any = None
    shared: bool = False

    def __post_init__(self) -> None:
        """Do nothing here; subclasses extend this hook via ``super()``."""
        return None
209
+
210
+
211
@dataclass
class TensorDictModuleConfig(ModelConfig):
    """Configuration for a :class:`~tensordict.nn.TensorDictModule`.

    ``module`` is required (``MISSING`` by default). It is built by
    ``_make_tensordict_module`` in this module.

    Example:
        >>> cfg = TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["action"])
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TensorDictModule)
        >>> assert module(observation=torch.randn(10, 10)).shape == (10, 10)

    .. seealso:: :class:`tensordict.nn.TensorDictModule`
    """

    module: MLPConfig = MISSING
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Delegate to the :class:`ModelConfig` hook."""
        super().__post_init__()
233
+
234
+
235
@dataclass
class TensorDictSequentialConfig(ModelConfig):
    """Configuration for a :class:`~tensordict.nn.TensorDictSequential`.

    ``partial_tolerant``, ``selected_out_keys`` and ``inplace`` are passed to
    the builder ``_make_tensordict_sequential`` in this module.

    Example:
        >>> cfg = TensorDictSequentialConfig(
        ...     modules=[
        ...         TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["hidden"]),
        ...         TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32), in_keys=["hidden"], out_keys=["action"])
        ...     ]
        ... )
        >>> seq = instantiate(cfg)
        >>> assert isinstance(seq, TensorDictSequential)

    .. seealso:: :class:`tensordict.nn.TensorDictSequential`
    """

    modules: Any | None = None
    partial_tolerant: bool = False
    selected_out_keys: Any | None = None
    inplace: bool | str | None = None
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tensordict_sequential"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Delegate to the :class:`ModelConfig` hook."""
        super().__post_init__()
263
+
264
+
265
@dataclass
class TanhNormalModelConfig(ModelConfig):
    """Configuration for a TanhNormal model.

    Any of ``in_keys``, ``param_keys`` or ``out_keys`` still ``None`` after
    construction gets a default: ``["observation"]``, ``["loc", "scale"]`` and
    ``["action"]``, respectively.

    Example:
        >>> cfg = TanhNormalModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.TanhNormal`
    """

    network: MLPConfig = MISSING
    eval_mode: bool = False

    extract_normal_params: bool = True
    scale_mapping: str = "biased_softplus_1.0"
    scale_lb: float = 1e-4

    param_keys: Any = None

    exploration_type: Any = "RANDOM"

    return_log_prob: bool = False

    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tanh_normal_model"
    )

    def __post_init__(self):
        """Run the parent hook, then fill in default keys that were left unset."""
        super().__post_init__()
        # The defaults are rebuilt on every call so no two instances share a list.
        fallbacks = (
            ("in_keys", ["observation"]),
            ("param_keys", ["loc", "scale"]),
            ("out_keys", ["action"]),
        )
        for attr_name, default_value in fallbacks:
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, default_value)
304
+
305
+
306
@dataclass
class ValueModelConfig(ModelConfig):
    """Configuration for a value model.

    ``network`` is required (``MISSING`` by default). The model is built by
    ``_make_value_model`` in this module.

    Example:
        >>> cfg = ValueModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.ValueOperator`
    """

    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_value_model"
    network: NetworkConfig = MISSING

    def __post_init__(self) -> None:
        """Delegate to the :class:`ModelConfig` hook."""
        super().__post_init__()
325
+
326
+
327
@dataclass
class TanhModuleConfig(ModelConfig):
    """Configuration for a :class:`~torchrl.modules.TanhModule`.

    ``spec``, ``low``, ``high`` and ``clamp`` are passed to the builder
    ``_make_tanh_module`` in this module.

    Example:
        >>> cfg = TanhModuleConfig(in_keys=["action"], out_keys=["action"], low=-1.0, high=1.0)
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TanhModule)

    .. seealso:: :class:`torchrl.modules.TanhModule`
    """

    spec: Any = None
    low: Any = None
    high: Any = None
    clamp: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_tanh_module"

    def __post_init__(self) -> None:
        """Delegate to the :class:`ModelConfig` hook."""
        super().__post_init__()
348
+
349
+
350
@dataclass
class AdditiveGaussianModuleConfig(ModelConfig):
    """Configuration for an :class:`~torchrl.modules.AdditiveGaussianModule`.

    The fields are passed to the builder ``_make_additive_gaussian_module`` in
    this module.

    Example:
        >>> cfg = AdditiveGaussianModuleConfig(
        ...     spec=None,
        ...     sigma_init=1.0,
        ...     sigma_end=0.1,
        ...     mean=0.0,
        ...     std=1.0,
        ...     action_key="action",
        ... )
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, AdditiveGaussianModule)

    .. seealso:: :class:`torchrl.modules.AdditiveGaussianModule`
    """

    spec: Any = None
    sigma_init: float = 1.0
    sigma_end: float = 0.1
    annealing_num_steps: int = 1000
    mean: float = 0.0
    std: float = 1.0
    action_key: Any = "action"
    safe: bool = False
    device: Any = None
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_additive_gaussian_module"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Delegate to the :class:`ModelConfig` hook."""
        super().__post_init__()
385
+
386
+
387
def _make_tensordict_module(*args, **kwargs):
    """Build a :class:`~tensordict.nn.TensorDictModule` from configuration kwargs.

    Keyword Args:
        module: the module to wrap. An object exposing ``_target_`` is
            instantiated with Hydra. A callable exposing ``func`` (a
            partial-like object) is called with no arguments. Anything else
            is used as-is.
        shared (bool, optional): when ``True``, ``share_memory()`` is called on
            the resulting module. Defaults to ``False``.
        **kwargs: every remaining keyword is forwarded to ``TensorDictModule``.

    Positional ``args`` are accepted but ignored.

    Returns:
        The constructed ``TensorDictModule``.
    """
    from hydra.utils import instantiate
    from tensordict.nn import TensorDictModule

    wrapped = kwargs.pop("module")
    make_shared = kwargs.pop("shared", False)

    if hasattr(wrapped, "_target_"):
        # Still a config object: let Hydra build the real module.
        wrapped = instantiate(wrapped)
    elif callable(wrapped) and hasattr(wrapped, "func"):
        # Partial-like callable: invoke it to obtain the module.
        wrapped = wrapped()

    td_module = TensorDictModule(wrapped, **kwargs)
    return td_module.share_memory() if make_shared else td_module
409
+
410
+
411
def _make_tensordict_sequential(*args, **kwargs):
    """Build a :class:`~tensordict.nn.TensorDictSequential` from configuration kwargs.

    Keyword Args:
        modules: a dict (or ``DictConfig``) or a list (or ``ListConfig``) of
            modules. Each entry exposing ``_target_`` is instantiated with
            Hydra, and each partial-like callable (exposing ``func``) is
            called. Other entries are used as-is.
        shared (bool, optional): when ``True``, ``share_memory()`` is called on
            the result. Defaults to ``False``.
        partial_tolerant (bool, optional): forwarded to ``TensorDictSequential``.
            Defaults to ``False``.
        selected_out_keys (optional): forwarded to ``TensorDictSequential``.
            Defaults to ``None``.
        inplace (optional): forwarded to ``TensorDictSequential``.
            Defaults to ``None``.

    Any other keyword arguments and all positional ``args`` are ignored.

    Raises:
        ValueError: if ``modules`` is neither a mapping nor a list type.
    """
    from hydra.utils import instantiate
    from omegaconf import DictConfig, ListConfig
    from tensordict.nn import TensorDictSequential

    raw_modules = kwargs.pop("modules")
    make_shared = kwargs.pop("shared", False)
    partial_tolerant = kwargs.pop("partial_tolerant", False)
    selected_out_keys = kwargs.pop("selected_out_keys", None)
    inplace = kwargs.pop("inplace", None)

    def _materialize(entry):
        # Turn a config or partial into a concrete module; pass others through.
        if hasattr(entry, "_target_"):
            return instantiate(entry)
        if callable(entry) and hasattr(entry, "func"):
            return entry()
        return entry

    if isinstance(raw_modules, (dict, DictConfig)):
        built = {name: _materialize(entry) for name, entry in raw_modules.items()}
    elif isinstance(raw_modules, (list, ListConfig)):
        built = [_materialize(entry) for entry in raw_modules]
    else:
        raise ValueError(
            f"modules must be a dict or list, got {type(raw_modules).__name__}"
        )

    sequential = TensorDictSequential(
        built,
        partial_tolerant=partial_tolerant,
        selected_out_keys=selected_out_keys,
        inplace=inplace,
    )
    return sequential.share_memory() if make_shared else sequential
453
+
454
+
455
def _make_tanh_normal_model(*args, **kwargs):
    """Build a TanhNormal policy as a ``ProbabilisticTensorDictSequential``.

    The pipeline is:

    1. ``network`` (instantiated with Hydra if it is a config, or called if it
       is a partial-like object). It is optionally followed by a
       ``NormalParamExtractor`` that splits the output into distribution
       parameters.
    2. A ``TensorDictModule`` that reads ``in_keys`` and writes ``param_keys``.
    3. A ``ProbabilisticTensorDictModule`` that reads ``param_keys``, samples
       a ``TanhNormal`` and writes ``out_keys``.

    Keyword Args:
        network: network config, partial-like callable, or module.
        in_keys: input key(s). Defaults to ``["observation"]``. A single
            string is treated as one key.
        param_keys: distribution-parameter key(s). Defaults to
            ``["loc", "scale"]``. A single string is treated as one key.
        out_keys: sampled-action key(s). Defaults to ``["action"]``. A single
            string is treated as one key.
        extract_normal_params (bool): append a ``NormalParamExtractor``.
            Defaults to ``True``.
        scale_mapping (str): forwarded to ``NormalParamExtractor``.
        scale_lb (float): forwarded to ``NormalParamExtractor``.
        return_log_prob (bool): forwarded to the probabilistic module.
        eval_mode (bool): when ``True``, ``.eval()`` is called on the result.
        exploration_type: forwarded as ``default_interaction_type``.
        shared (bool): when ``True``, ``share_memory()`` is applied to the
            parameter-producing module.
        **kwargs: every remaining keyword is forwarded to
            ``ProbabilisticTensorDictModule``.
    """
    from hydra.utils import instantiate
    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )
    from torchrl.modules import NormalParamExtractor, TanhNormal

    def _as_key_list(keys, default):
        # ``list("action")`` would split a bare string into characters, so a
        # single string is wrapped instead. ``None`` falls back to the default.
        if keys is None:
            return list(default)
        if isinstance(keys, str):
            return [keys]
        return list(keys)

    # Extract parameters
    network = kwargs.pop("network")
    in_keys = _as_key_list(kwargs.pop("in_keys", None), ["observation"])
    param_keys = _as_key_list(kwargs.pop("param_keys", None), ["loc", "scale"])
    out_keys = _as_key_list(kwargs.pop("out_keys", None), ["action"])
    extract_normal_params = kwargs.pop("extract_normal_params", True)
    scale_mapping = kwargs.pop("scale_mapping", "biased_softplus_1.0")
    scale_lb = kwargs.pop("scale_lb", 1e-4)
    return_log_prob = kwargs.pop("return_log_prob", False)
    eval_mode = kwargs.pop("eval_mode", False)
    exploration_type = kwargs.pop("exploration_type", "RANDOM")
    shared = kwargs.pop("shared", False)

    # Materialize the network from a config or partial, if needed.
    if hasattr(network, "_target_"):
        network = instantiate(network)
    elif callable(network) and hasattr(network, "func"):  # partial function
        network = network()

    if extract_normal_params:
        # Split the raw network output into distribution parameters.
        network = torch.nn.Sequential(
            network,
            NormalParamExtractor(scale_mapping=scale_mapping, scale_lb=scale_lb),
        )

    module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
    if shared:
        module = module.share_memory()

    prob_module = ProbabilisticTensorDictModule(
        in_keys=param_keys,
        out_keys=out_keys,
        distribution_class=TanhNormal,
        return_log_prob=return_log_prob,
        default_interaction_type=exploration_type,
        **kwargs,
    )

    result = ProbabilisticTensorDictSequential(module, prob_module)
    if eval_mode:
        result.eval()
    return result
510
+
511
+
512
def _make_value_model(*args, **kwargs):
    """Wrap a network in a :class:`~torchrl.modules.ValueOperator`.

    Keyword Args:
        network: network config (instantiated with Hydra when it exposes
            ``_target_``), partial-like callable (called with no arguments),
            or a ready module.
        shared (bool, optional): when ``True``, ``share_memory()`` is called on
            the operator. Defaults to ``False``.
        **kwargs: every remaining keyword is forwarded to ``ValueOperator``.

    Positional ``args`` are accepted but ignored.
    """
    from hydra.utils import instantiate

    from torchrl.modules import ValueOperator

    net = kwargs.pop("network")
    make_shared = kwargs.pop("shared", False)

    if hasattr(net, "_target_"):
        net = instantiate(net)
    elif callable(net) and hasattr(net, "func"):
        # Partial-like callable: invoke it to obtain the network.
        net = net()

    operator = ValueOperator(net, **kwargs)
    return operator.share_memory() if make_shared else operator
535
+
536
+
537
def _make_tanh_module(*args, **kwargs):
    """Create a :class:`~torchrl.modules.TanhModule` from configuration kwargs.

    ``shared`` is removed and ignored. ``in_keys`` and ``out_keys`` given as an
    OmegaConf ``ListConfig`` are converted to plain Python lists. All other
    keyword arguments go straight to ``TanhModule``. Positional ``args`` are
    ignored.
    """
    from omegaconf import ListConfig

    from torchrl.modules import TanhModule

    # Popped so it is not forwarded; it has no effect on the TanhModule.
    kwargs.pop("shared", False)

    for key_field in ("in_keys", "out_keys"):
        value = kwargs.get(key_field)
        if isinstance(value, ListConfig):
            kwargs[key_field] = list(value)

    return TanhModule(**kwargs)
551
+
552
+
553
def _make_additive_gaussian_module(*args, **kwargs):
    """Create an :class:`AdditiveGaussianModule` from configuration kwargs.

    ``shared``, ``in_keys`` and ``out_keys`` are removed and ignored. An
    ``action_key`` given as an OmegaConf ``ListConfig`` becomes its single
    element when it has exactly one entry, and a tuple (nested key) otherwise.
    All other keyword arguments go straight to ``AdditiveGaussianModule``.
    Positional ``args`` are ignored.
    """
    from omegaconf import ListConfig

    from torchrl.modules.tensordict_module.exploration import AdditiveGaussianModule

    # Presumably supplied by the base config; the module does not receive them.
    for unused in ("shared", "in_keys", "out_keys"):
        kwargs.pop(unused, None)

    action_key = kwargs.get("action_key")
    if isinstance(action_key, ListConfig):
        parts = list(action_key)
        kwargs["action_key"] = parts[0] if len(parts) == 1 else tuple(parts)

    return AdditiveGaussianModule(**kwargs)