torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+
10
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
11
+
12
+
13
@dataclass
class AdamConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.Adam`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    # NOTE(review): torch.optim.Adam itself defaults eps to 1e-8; the larger
    # value here looks deliberate -- confirm before changing.
    eps: float = 1e-4
    weight_decay: float = 0.0
    amsgrad: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.Adam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
27
+
28
+
29
@dataclass
class AdamWConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.AdamW`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    amsgrad: bool = False

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    maximize: bool = False
    foreach: bool | None = None
    capturable: bool = False
    differentiable: bool = False
    fused: bool | None = None

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.AdamW"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
48
+
49
+
50
@dataclass
class AdamaxConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.Adamax`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.Adamax"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
63
+
64
+
65
@dataclass
class SGDConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.SGD`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.SGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
82
+
83
+
84
@dataclass
class RMSpropConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.RMSprop`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-2
    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = False

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.RMSprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
102
+
103
+
104
@dataclass
class AdagradConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.Adagrad`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-2
    lr_decay: float = 0.0
    weight_decay: float = 0.0
    initial_accumulator_value: float = 0.0
    eps: float = 1e-10

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    maximize: bool = False
    foreach: bool | None = None
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.Adagrad"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
121
+
122
+
123
@dataclass
class AdadeltaConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.Adadelta`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1.0
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0.0

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.Adadelta"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
139
+
140
+
141
@dataclass
class RpropConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.Rprop`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters; both pairs are two-element tuples.
    lr: float = 1e-2
    etas: tuple[float, float] = (0.5, 1.2)
    step_sizes: tuple[float, float] = (1e-6, 50.0)

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.Rprop"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
156
+
157
+
158
@dataclass
class ASGDConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.ASGD`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-2
    lambd: float = 1e-4
    alpha: float = 0.75
    t0: float = 1e6
    weight_decay: float = 0.0

    # Implementation switches; ``None`` leaves the choice to the optimizer.
    foreach: bool | None = None
    maximize: bool = False
    differentiable: bool = False

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.ASGD"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
175
+
176
+
177
@dataclass
class LBFGSConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.LBFGS`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters; ``None`` means "not set by this config".
    lr: float = 1.0
    max_iter: int = 20
    max_eval: int | None = None
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: str | None = None

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.LBFGS"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
193
+
194
+
195
@dataclass
class RAdamConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.RAdam`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.RAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
208
+
209
+
210
@dataclass
class NAdamConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.NAdam`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 2e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum_decay: float = 4e-3

    # Implementation switch; ``None`` leaves the choice to the optimizer.
    foreach: bool | None = None

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.NAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
225
+
226
+
227
@dataclass
class SparseAdamConfig(ConfigBase):
    """Default hyperparameters for building a ``torch.optim.SparseAdam`` optimizer.

    Field names mirror the keyword arguments of the ``_target_`` class.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8

    # Instantiation metadata: dotted path of the class to build and whether
    # to build it partially (presumably Hydra-style -- confirm).
    _target_: str = "torch.optim.SparseAdam"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
239
+
240
+
241
@dataclass
class LionConfig(ConfigBase):
    """Default hyperparameters for a Lion optimizer.

    The class only stores values; the optimizer itself is located through the
    dotted path held in ``_target_``.
    """

    # Optimizer hyperparameters.
    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0

    # NOTE(review): the documented torch.optim algorithm list does not appear
    # to include a ``Lion`` class -- confirm this path resolves in the
    # supported PyTorch versions before relying on this config.
    _target_: str = "torch.optim.Lion"
    # Build partially (presumably Hydra-style -- confirm).
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
@@ -0,0 +1,191 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
12
+
13
+
14
@dataclass
class WeightSyncSchemeConfig(ConfigBase):
    """Base settings shared by weight synchronization scheme configs."""

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.WeightSyncScheme"
    _partial_: bool = False

    # Argument common to every scheme.
    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
26
+
27
+
28
@dataclass
class MultiProcessWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``MultiProcessWeightSyncScheme``.

    The target scheme synchronizes weights between processes, creating
    transports that talk over multiprocessing pipes.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.MultiProcessWeightSyncScheme"
    _partial_: bool = False

    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
43
+
44
+
45
@dataclass
class SharedMemWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``SharedMemWeightSyncScheme``.

    The target scheme updates weights in place through shared memory, so
    workers observe new weights without any explicit message passing.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme"
    _partial_: bool = False

    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
60
+
61
+
62
@dataclass
class NoWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``NoWeightSyncScheme``.

    The target scheme is a no-op that turns weight synchronization off
    entirely.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.NoWeightSyncScheme"
    _partial_: bool = False

    # Not really used by this scheme; kept so all scheme configs look alike.
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
76
+
77
+
78
@dataclass
class RayWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``RayWeightSyncScheme``.

    The target scheme synchronizes weights across distributed workers (Ray
    actors) by means of Ray's object store and remote calls.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.RayWeightSyncScheme"
    _partial_: bool = False

    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
93
+
94
+
95
@dataclass
class RayModuleTransformSchemeConfig(ConfigBase):
    """Settings for ``RayModuleTransformScheme``.

    The target scheme is meant for ``RayModuleTransform`` actors, i.e. for
    updating models that live inside Ray actors.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.RayModuleTransformScheme"
    _partial_: bool = False

    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
110
+
111
+
112
@dataclass
class RPCWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``RPCWeightSyncScheme``.

    The target scheme relies on ``torch.distributed.rpc`` calls to
    synchronize weights across distributed workers.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.RPCWeightSyncScheme"
    _partial_: bool = False

    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
127
+
128
+
129
@dataclass
class DistributedWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``DistributedWeightSyncScheme``.

    The target scheme synchronizes weights across distributed workers with
    ``torch.distributed`` send/recv primitives.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.DistributedWeightSyncScheme"
    _partial_: bool = False

    # Communication backend name, e.g. "gloo" or "nccl".
    backend: str = "gloo"
    sync: bool = True
    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
146
+
147
+
148
@dataclass
class VLLMWeightSyncSchemeConfig(ConfigBase):
    """Settings for ``VLLMWeightSyncScheme``.

    The target scheme broadcasts weights from a trainer to vLLM inference
    workers through collective communication (NCCL), with parallelism
    support.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.llm.VLLMWeightSyncScheme"
    _partial_: bool = False

    # Rendezvous endpoint. Per the original notes, a ``None`` address falls
    # back to "localhost" and a ``None`` port is auto-assigned.
    master_address: str | None = None
    master_port: int | None = None
    # Equals tp_size * dp_size * pp_size.
    gpus_per_replica: int = 1
    num_replicas: int = 1
    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"
    # Intended to hold a torch.device, a str or an int.
    device: Any = 0

    def __post_init__(self) -> None:
        """No-op hook: nothing is validated or derived after construction."""
168
+
169
+
170
@dataclass
class VLLMDoubleBufferSyncSchemeConfig(ConfigBase):
    """Settings for ``VLLMDoubleBufferSyncScheme``.

    The target scheme moves weights to vLLM through double-buffered,
    memory-mapped storage: TensorDict memory-mapping is used to pass the
    weights via the filesystem.

    Raises:
        ValueError: at construction time when ``remote_addr`` is ``None``.
    """

    # Instantiation metadata: dotted path of the class to build; the scheme
    # is built fully rather than partially.
    _target_: str = "torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme"
    _partial_: bool = False

    # Directory the sender writes weights to. Effectively required: the
    # ``None`` default is rejected in ``__post_init__``.
    remote_addr: str | None = None
    # Directory the receiver reads weights from; falls back to
    # ``remote_addr`` when left as ``None``.
    local_addr: str | None = None
    # Number of threads used for memmap operations.
    num_threads: int = 1
    # Accepted values: "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """Require ``remote_addr`` and default ``local_addr`` to it."""
        sender_dir = self.remote_addr
        if sender_dir is None:
            raise ValueError("remote_addr is required for VLLMDoubleBufferSyncScheme")
        # Reader and writer share one directory unless told otherwise.
        if self.local_addr is None:
            self.local_addr = sender_dir
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
12
+
13
+
14
@dataclass
class WeightUpdaterConfig(ConfigBase):
    """Root configuration shared by weight updaters."""

    # Import path of the updater base class this config describes.
    _target_: str = "torchrl.collectors.WeightUpdaterBase"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
23
+
24
+
25
@dataclass
class VanillaWeightUpdaterConfig(ConfigBase):
    """Settings used to build a ``VanillaWeightUpdater``.

    The described updater is a simple implementation that refreshes local
    policy weights by fetching them directly from a specified source.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.VanillaWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    # Expected type: Callable[[], TensorDictBase] | None
    weight_getter: Any = None
    # Expected type: TensorDictBase
    policy_weights: Any = None

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
42
+
43
+
44
@dataclass
class MultiProcessedWeightUpdaterConfig(ConfigBase):
    """Settings used to build a ``MultiProcessedWeightUpdater``.

    The described updater is a remote weight updater that keeps policy
    weights synchronized across multiple processes or devices in a
    multiprocessed environment.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.MultiProcessedWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    # Expected type: Callable[[], TensorDictBase] | None
    get_server_weights: Any = None
    # Expected type: dict[torch.device, TensorDictBase]
    policy_weights: Any = None

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
61
+
62
+
63
@dataclass
class RayWeightUpdaterConfig(ConfigBase):
    """Settings used to build a ``RayWeightUpdater``.

    The described updater is a remote weight updater that keeps policy
    weights synchronized across remote workers through Ray's distributed
    computing capabilities.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.RayWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    # Expected type: TensorDictBase
    policy_weights: Any = None
    # Expected type: list
    remote_collectors: Any = None
    max_interval: int = 0

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
81
+
82
+
83
@dataclass
class RPCWeightUpdaterConfig(ConfigBase):
    """Settings used to build an ``RPCWeightUpdater``.

    The described updater is a remote weight updater that keeps policy
    weights synchronized across remote workers through RPC communication.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.distributed.RPCWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    collector_infos: Any = None
    collector_class: Any = None
    collector_rrefs: Any = None
    # Expected type: TensorDictBase
    policy_weights: Any = None
    num_workers: int = 0

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
103
+
104
+
105
@dataclass
class DistributedWeightUpdaterConfig(ConfigBase):
    """Settings used to build a ``DistributedWeightUpdater``.

    The described updater is a remote weight updater that keeps policy
    weights synchronized across distributed workers, communicating through a
    dictionary-like store.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.distributed.DistributedWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    # Expected type: dict[str, str]
    store: Any = None
    # Expected type: TensorDictBase
    policy_weights: Any = None
    num_workers: int = 0
    sync: bool = True

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
124
+
125
+
126
@dataclass
class RemoteModuleWeightUpdaterConfig(ConfigBase):
    """Settings used to build a ``RemoteModuleWeightUpdater``.

    The described updater targets remote nn.Modules and requires the weights
    to be passed explicitly. It is meant for cases where the master collector
    has no direct access to the worker weights.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.RemoteModuleWeightUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""
139
+
140
+
141
@dataclass
class vLLMUpdaterConfig(ConfigBase):
    """Settings used to build a ``vLLMUpdater``.

    The described updater sends weights to vLLM workers; it supports local
    vLLM instances as well as remote Ray actors used for LLM inference.
    """

    # Import path of the updater this config describes.
    _target_: str = "torchrl.collectors.llm.vLLMUpdater"
    # NOTE(review): presumably Hydra's ``_partial_`` instantiate flag -- confirm.
    _partial_: bool = True

    # --- Arguments forwarded to the constructor ---
    master_address: str | None = None
    master_port: int | None = None
    # Expected type: dict[str, tuple[torch.dtype, torch.Size]] | None
    model_metadata: Any = None
    vllm_tp_size: int | None = None

    def __post_init__(self) -> None:
        """No-op hook invoked once the dataclass fields have been set."""