torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,32 @@
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Public entry point for vLLM weight-synchronization schemes.

Re-exports the NCCL-based (``vllm_nccl``) and double-buffered,
memmap-based (``vllm_double_buffer``) sender/receiver/transport classes
so they can be imported directly from this package.
"""

from .vllm_double_buffer import (
    VLLMDoubleBufferSyncScheme,
    VLLMDoubleBufferTransport,
    VLLMDoubleBufferWeightReceiver,
    VLLMDoubleBufferWeightSender,
)
from .vllm_nccl import (
    get_model_metadata,
    VLLMCollectiveTransport,
    VLLMWeightReceiver,
    VLLMWeightSender,
    VLLMWeightSyncScheme,
)

__all__ = [
    # vLLM NCCL-based weight sync
    "VLLMWeightSyncScheme",
    "VLLMWeightSender",
    "VLLMWeightReceiver",
    "VLLMCollectiveTransport",
    "get_model_metadata",
    # vLLM double-buffer weight sync
    "VLLMDoubleBufferSyncScheme",
    "VLLMDoubleBufferWeightSender",
    "VLLMDoubleBufferWeightReceiver",
    "VLLMDoubleBufferTransport",
]
@@ -0,0 +1,370 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """vLLM weight synchronization using double-buffered shared memory.
7
+
8
+ This module provides weight synchronization for vLLM engines using a double-buffer
9
+ approach with memory-mapped TensorDict storage.
10
+
11
+ **Architecture Overview**
12
+
13
+ The double-buffer synchronization uses a simpler architecture compared to NCCL:
14
+
15
+ 1. **Sender (Trainer)**
16
+ - Extracts weights from the training model
17
+ - Writes weights to shared directory using TensorDict.memmap
18
+ - No coordination needed - receiver pulls when ready
19
+
20
+ 2. **Receiver (vLLM Worker)**
21
+ - Uses RPC to tell all vLLM workers to load from shared directory
22
+ - Each worker reads weights and calls model.load_weights()
23
+ - Can trigger at any time (pull-based)
24
+
25
+ **Key Differences from NCCL**
26
+
27
+ - **Async vs Sync**: Double-buffer is asynchronous (no coordination required)
28
+ - **Push vs Pull**: Sender writes, receiver pulls when ready via RPC
29
+ - **Simplicity**: No NCCL collectives, uses file I/O
30
+ - **Storage**: Uses shared filesystem instead of GPU-GPU transfer
31
+
32
+ **RPC Pattern**
33
+
34
+ Like the NCCL implementation, this uses RPC to coordinate workers:
35
+ - RPC tells workers: "load weights from this directory"
36
+ - Workers read from shared storage independently
37
+ - Each worker calls `model_runner.model.load_weights()`
38
+
39
+ **Usage Example**
40
+
41
+ .. code-block:: python
42
+
43
+ # Create scheme with shared directory
44
+ scheme = VLLMDoubleBufferSyncScheme(
45
+ remote_addr="/shared/weights",
46
+ num_threads=4
47
+ )
48
+
49
+ # Sender side (trainer)
50
+ sender = scheme.create_sender()
51
+ sender.register_model(policy_model)
52
+ sender.update_weights() # Writes to /shared/weights
53
+
54
+ # Receiver side (vLLM worker - AsyncVLLM)
55
+ receiver = scheme.create_receiver(vllm_engine)
56
+ receiver.poll_and_apply() # RPC to workers -> load from /shared/weights
57
+
58
+ **Node-to-Node Transfer**
59
+
60
+ For distributed setups, you can use different addresses:
61
+ - Sender writes to local path
62
+ - Use NFS, rsync, or other file sync mechanisms
63
+ - Receiver reads from its local mount point
64
+ """
65
+
66
+ from __future__ import annotations
67
+
68
+ from typing import Any, Literal
69
+
70
+ from tensordict import TensorDict, TensorDictBase
71
+ from torchrl._utils import logger
72
+ from torchrl.weight_update.weight_sync_schemes import WeightStrategy, WeightSyncScheme
73
+
74
+
75
class VLLMDoubleBufferTransport:
    """Transport for vLLM using double-buffered memory-mapped storage.

    Weights travel through a shared directory: the sender memory-maps a
    TensorDict into ``remote_addr``; the receiver loads it back from
    ``local_addr``.

    Args:
        remote_addr: Directory path where sender writes weights.
        local_addr: Directory path where receiver reads weights.
            If None, uses same path as remote_addr (for local testing).
        num_threads: Number of threads for memmap operations.
    """

    def __init__(
        self, remote_addr: str, local_addr: str | None = None, num_threads: int = 1
    ):
        self.remote_addr = remote_addr
        # Default the read side to the write side for single-machine setups.
        self.local_addr = remote_addr if local_addr is None else local_addr
        self.num_threads = num_threads

    def send_weights(self, model_id: str, weights: Any) -> None:
        """Writes the weights to a shared directory.

        Args:
            model_id: Identifier for the model (used for logging).
            weights: TensorDict or dict of weights to write.
        """
        if isinstance(weights, dict):
            # Plain dicts are first wrapped into a batchless TensorDict.
            weights = TensorDict(weights, batch_size=[])
        elif isinstance(weights, TensorDictBase) and weights.batch_size == ():
            # Clone before memmapping so the caller's TensorDict is untouched.
            weights = weights.clone()

        logger.info(f"Writing weights for model '{model_id}' to {self.remote_addr}")
        weights.memmap(self.remote_addr, num_threads=self.num_threads)
        logger.info(f"Weights written successfully to {self.remote_addr}")

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: Any = None,
    ) -> Any | None:
        """Reads the weights from the shared directory.

        Args:
            timeout: Ignored (file-based transport is instant).
            weights: Ignored.
            model: Ignored.
            strategy: Ignored.

        Returns:
            TensorDict with flattened keys containing the weights.
        """
        # All arguments exist only for API compatibility: a file read
        # completes immediately, so there is nothing to wait on.
        logger.info(f"Reading weights from {self.local_addr}")
        loaded = TensorDict.load_memmap(self.local_addr).flatten_keys(".")
        logger.info(f"Weights read successfully from {self.local_addr}")
        return loaded

    def check_connection(self) -> bool:
        """Check if the transport is ready.

        For file-based transport, always returns True.
        """
        return True
147
+
148
+
149
class VLLMDoubleBufferSyncScheme(WeightSyncScheme):
    """Weight synchronization scheme for vLLM using double-buffered storage.

    Transfers weights from a trainer to vLLM inference workers through
    memory-mapped TensorDict storage on a shared filesystem. Unlike
    NCCL-based schemes, no process-group coordination is required.

    Args:
        remote_addr: Directory path where sender writes weights.
        local_addr: Directory path where receiver reads weights.
            If None, uses same path as remote_addr (for local testing).
        num_threads: Number of threads for memmap operations. Defaults to 1.
        strategy: Weight extraction strategy ("tensordict" or "state_dict").

    Example:
        >>> # Local testing (same machine)
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/tmp/weights",
        ...     strategy="tensordict"
        ... )
        >>>
        >>> # Distributed setup (different machines)
        >>> # On trainer node:
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/mnt/shared/weights",  # NFS mount
        ...     num_threads=4
        ... )
        >>>
        >>> # On vLLM worker node:
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/mnt/shared/weights",  # Same NFS mount
        ...     num_threads=4
        ... )
    """

    def __init__(
        self,
        remote_addr: str,
        local_addr: str | None = None,
        num_threads: int = 1,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
    ):
        if local_addr is None:
            # Read side defaults to the write side for single-host use.
            local_addr = remote_addr
        self.remote_addr = remote_addr
        self.local_addr = local_addr
        self.num_threads = num_threads
        self.strategy_name = strategy

    def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport:
        """Create transport for double-buffered storage.

        Args:
            **kwargs: Not used for file-based transport (kept for API compatibility).

        Returns:
            A VLLMDoubleBufferTransport instance.
        """
        return VLLMDoubleBufferTransport(
            remote_addr=self.remote_addr,
            local_addr=self.local_addr,
            num_threads=self.num_threads,
        )

    def create_sender(self) -> VLLMDoubleBufferWeightSender:
        """Create a weight sender for the trainer process."""
        return VLLMDoubleBufferWeightSender(self)

    def create_receiver(self, vllm_engine) -> VLLMDoubleBufferWeightReceiver:
        """Create a weight receiver for a vLLM worker process.

        Args:
            vllm_engine: The vLLM engine instance (must have .llm_engine.model_executor attribute).
        """
        return VLLMDoubleBufferWeightReceiver(self, vllm_engine)
222
+
223
+
224
class VLLMDoubleBufferWeightSender:
    """Sends weights to vLLM workers using double-buffered storage.

    Extracts weights from a training model and memmaps them into the
    scheme's shared directory, from which receivers pull when ready.

    Example:
        >>> sender = scheme.create_sender()
        >>> sender.register_model(policy_model)
        >>>
        >>> # During training loop
        >>> sender.update_weights()  # Writes current weights to shared storage
    """

    def __init__(self, scheme: VLLMDoubleBufferSyncScheme):
        self._scheme = scheme
        self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
        # Set by register_model(); update_weights() requires both.
        self._model_ref = None
        self._transport = None

    def register_model(self, model: Any) -> None:
        """Register the model to extract weights from.

        Args:
            model: The model to extract weights from (e.g., TransformersWrapper).
        """
        import weakref

        # Hold the model weakly so the sender never keeps it alive.
        self._model_ref = weakref.ref(model)
        self._transport = self._scheme.create_transport()
        logger.info(
            f"Registered model for double-buffer weight sync to {self._scheme.remote_addr}"
        )

    def update_weights(self, weights: Any | None = None) -> None:
        """Extract and write weights to shared storage.

        Args:
            weights: Optional weights to send. If None, extracts from registered model.

        Raises:
            RuntimeError: If register_model was never called, or the
                registered model has been garbage-collected.
        """
        if self._transport is None:
            raise RuntimeError("Transport not initialized. Call register_model first.")

        if weights is None:
            model = self._model_ref()
            if model is None:
                raise RuntimeError("Model reference is dead")
            weights = self._strategy.extract_weights(model)
        elif hasattr(weights, "state_dict"):
            # A module was passed instead of raw weights: extract them.
            weights = self._strategy.extract_weights(weights)

        self._transport.send_weights("vllm_model", weights)
283
+
284
+
285
class VLLMDoubleBufferWeightReceiver:
    """Receives weights in a vLLM worker using double-buffered storage.

    Reads weights from the scheme's shared directory and loads them into
    the vLLM engine through its load_weights interface.

    Example:
        >>> receiver = scheme.create_receiver(vllm_engine)
        >>>
        >>> # Poll for new weights
        >>> if receiver.poll_and_apply():
        ...     print("Weights updated!")
    """

    def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine):
        self._scheme = scheme
        self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
        self._vllm_engine = vllm_engine
        self._transport = scheme.create_transport()
        logger.info(
            f"Initialized double-buffer receiver reading from {self._scheme.local_addr}"
        )

    def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None:
        """Apply weights to vLLM engine using RPC.

        For AsyncVLLM engines (those exposing ``collective_rpc``), every
        worker is told via RPC to load weights from the shared storage
        directory; for a plain local engine, the tensors are loaded
        directly into the driver worker's model.

        Args:
            weights: TensorDict with flattened keys containing weights.
            inplace: Whether to apply weights in place. Default is `True`.

        Raises:
            ValueError: If ``inplace`` is False (unsupported here).
        """
        if not inplace:
            raise ValueError("Cannot apply weights out of place for vLLM double-buffer")
        logger.info("Applying weights to vLLM engine via RPC")

        # Flattened TensorDict -> list of (name, tensor) pairs.
        weights_list = list(weights.items())

        if not hasattr(self._vllm_engine, "collective_rpc"):
            # Direct path for local LLM (non-AsyncVLLM): load straight
            # into the driver worker's model.
            logger.info("Using direct load for local LLM")
            if hasattr(self._vllm_engine, "llm_engine"):
                engine = self._vllm_engine.llm_engine
            else:
                engine = self._vllm_engine
            engine.model_executor.driver_worker.model_runner.model.load_weights(
                weights_list
            )
            logger.info("Weights loaded successfully")
            return

        # AsyncVLLM path: RPC tells every worker to call
        # 'load_weights_from_storage' against the shared directory.
        logger.info(
            f"Using RPC to load {len(weights_list)} weights across all replicas"
        )
        futures = self._vllm_engine.collective_rpc(
            method="load_weights_from_storage",
            args=(str(self._scheme.local_addr), self._transport.num_threads),
        )

        # Block until every worker has finished loading.
        import ray

        ray.get(futures)
        logger.info("Weights loaded successfully via RPC")

    def poll_and_apply(self, timeout: float = 180.0) -> bool:
        """Poll for and apply weights from shared storage.

        Args:
            timeout: Not used for file-based transport (kept for API compatibility).

        Returns:
            True once the weights have been read and applied (errors
            propagate as exceptions rather than a False return).
        """
        self.apply_weights(self._transport.receive_weights())
        return True