torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/llm/vllm_nccl.py
@@ -0,0 +1,710 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """vLLM weight synchronization for the v2 API.
+
+ This module provides weight synchronization for vLLM engines using a two-layer
+ architecture:
+
+ **Architecture Overview**
+
+ The weight synchronization uses two separate layers:
+
+ 1. **RPC Layer** (Coordination)
+    - Signals workers when a collective operation will begin
+    - Can be implemented with different backends (Ray, torch.distributed.rpc, etc.)
+    - Tells vLLM workers: "prepare to receive weights via collective"
+    - Currently supports Ray as the RPC backend
+
+ 2. **Collective Layer** (Data Transfer)
+    - Performs the actual weight broadcast using NCCL
+    - High-bandwidth GPU-to-GPU communication
+    - All ranks participate simultaneously in the collective
+
+ **Why Two Layers?**
+
+ Separating RPC and collectives provides:
+ - **Flexibility**: Swap RPC backends (Ray, RPC, gRPC) without changing collectives
+ - **Clarity**: Coordination logic separate from data transfer
+ - **Performance**: Use optimal transport for each (RPC for signals, NCCL for data)
+
+ **Flow Example (Ray Backend)**
+
+ .. code-block::
+
+     Trainer (rank 0)                    vLLM Workers (ranks 1+)
+     ================                    =======================
+
+     # 1. RPC: Signal collective start
+     trainer.update_weights() ---------> [Ray RPC] --------> receiver.init_all_workers_group()
+                                         "I'm ready for collective"
+
+     # 2. Collective: Broadcast weights
+     NCCL broadcast -------------------- [GPU-GPU] ---------> NCCL receive
+     (high bandwidth)                                         (parallel)
+
+     # 3. RPC: Confirmation (optional)
+     "broadcast done" <----------------- [Ray RPC] --------- "weights applied"
+
+ **Extending to Other Backends**
+
+ To add a new RPC backend (e.g., torch.distributed.rpc):
+
+ 1. Implement an RPC coordinator in the sender/receiver
+ 2. Replace Ray remote calls with your RPC mechanism
+ 3. Keep the collective layer unchanged (it's backend-agnostic)
+
+ .. rubric:: Example
+
+ .. code-block:: python
+
+     class TorchRPCVLLMReceiver(VLLMWeightReceiver):
+         def init_all_workers_group(self, metadata):
+             # Use torch.distributed.rpc instead of Ray
+             torch.distributed.rpc.rpc_sync(
+                 "trainer",
+                 lambda: "ready",
+             )
+             super().init_all_workers_group(metadata)  # Collective init
+
+ **Current Implementation (Ray Backend)**
+
+ .. code-block:: python
+
+     # Trainer actor (provides RPC endpoint)
+     trainer = RayWorkerTransformer.as_remote().options(
+         name="Trainer"  # Named for discovery
+     ).remote(scheme_config)
+
+     # Receiver actor (uses RPC to coordinate)
+     receiver = RayWorkerVLLM.as_remote().remote(
+         scheme_config, trainer_actor_name="Trainer"
+     )
+
+     # RPC Layer: Both actors call init() via Ray remote calls
+     # This coordinates the collective handshake
+     ray.get([trainer.init.remote(), receiver.init.remote()])
+
+     # RPC Layer: Trigger update via Ray remote call
+     # Collective Layer: NCCL broadcast happens automatically
+     ray.get(trainer.update_weights.remote(modify_weights=True))
+
+ In this setup:
+ - **Ray provides RPC**: Named actors, ``remote()`` calls, ``ray.get()``
+ - **NCCL provides collectives**: GPU-GPU weight broadcast
+ - **Loose coupling**: Can replace Ray with any RPC mechanism
+ """
+
+ from __future__ import annotations
+
+ import time
+
+ from typing import Any, Literal
+
+ import torch
+ import torch.distributed
+ from tensordict import TensorDictBase
+
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.modules.llm.backends import stateless_init_process_group
+ from torchrl.weight_update.weight_sync_schemes import WeightStrategy, WeightSyncScheme
+
+ # ============================================================================
+ # vLLM Transport using Collective Communication
+ # ============================================================================
+
+
+ class VLLMCollectiveTransport:
+     """Transport for vLLM using collective communication (NCCL).
+
+     **COLLECTIVE LAYER ONLY** - This class handles the data transfer layer.
+     RPC coordination is handled separately by the caller (sender/receiver).
+
+     This transport uses PyTorch distributed collectives to broadcast weights
+     from a trainer (rank 0) to vLLM workers (ranks 1+).
+
+     **Separation of Concerns:**
+     - This class: NCCL collective operations (GPU-GPU data transfer)
+     - Caller (sender/receiver): RPC coordination (when to start collective)
+
+     Args:
+         master_address: Address of the master node for distributed init.
+         master_port: Port of the master node for distributed init.
+         rank: Rank of this process (0 for trainer, 1+ for vLLM workers).
+         world_size: Total number of processes (1 + num_replicas * gpus_per_replica).
+         device: Device to use for communication (typically cuda:0).
+         vllm_engine: Optional vLLM engine reference (for receiver side).
+
+     Note:
+         The RPC layer (e.g., Ray remote calls) must ensure all ranks call
+         init_all_workers_group() simultaneously before any collective operations.
+     """
+
+     def __init__(
+         self,
+         master_address: str,
+         master_port: int,
+         rank: int | None,
+         world_size: int,
+         device: torch.device | str | int | None = None,
+         vllm_engine: Any | None = None,
+     ):
+         self.master_address = master_address
+         self.master_port = master_port
+         self.rank = rank
+         self.world_size = world_size
+         self.vllm_engine = vllm_engine
+         self._comm_group = None
+         self._model_metadata = None
+
+         # Ray sets CUDA_VISIBLE_DEVICES, so each actor sees only device 0
+         # PyNcclCommunicator expects an integer device index
+         if device is None:
+             self.device = 0  # Default to device 0 (Ray convention)
+         elif isinstance(device, str):
+             # Extract device index from "cuda:X"
+             self.device = int(device.split(":")[-1]) if ":" in device else 0
+         elif isinstance(device, torch.device):
+             # Extract index from torch.device
+             self.device = device.index if device.index is not None else 0
+         else:
+             self.device = device
+
+     def init_all_workers_group(
+         self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]]
+     ):
+         """Initialize the collective communication group.
+
+         Args:
+             model_metadata: Dict mapping param names to (dtype, shape) tuples.
+         """
+         self._model_metadata = model_metadata
+
+         if self.rank == 0:
+             # Trainer side - initialize process group
+             torchrl_logger.debug(
+                 f"Initializing trainer collective group: rank={self.rank}, world_size={self.world_size}, device={self.device}"
+             )
+             # Ray sets CUDA_VISIBLE_DEVICES, so we always use device 0
+             # Set CUDA device before initializing NCCL to avoid segfaults
+             torch.cuda.set_device(self.device)
+             torchrl_logger.debug(f"Set CUDA device to {self.device}")
+
+             self._comm_group = stateless_init_process_group(
+                 self.master_address,
+                 self.master_port,
+                 self.rank,
+                 self.world_size,
+                 device=self.device,
+             )
+             torchrl_logger.debug("Trainer collective group initialized successfully")
+         else:
+             # vLLM worker side - initialize through engine
+             if self.vllm_engine is None:
+                 raise ValueError("vllm_engine must be provided for worker ranks")
+
+             torchrl_logger.debug(
+                 "Initializing vLLM worker collective group through engine"
+             )
+             # Call vLLM engine's init method - it returns futures for all workers
+             # Workers will start NCCL init in background threads and return immediately
+             refs = self.vllm_engine.init_weight_update_group(
+                 master_address=self.master_address,
+                 master_port=self.master_port,
+             )
+
+             # Wait for RPCs to complete - ensures workers have dispatched their NCCL init threads
+             import ray
+
+             ray.get(refs)
+             torchrl_logger.debug(
+                 f"All {len(refs)} vLLM workers have dispatched NCCL init RPCs"
+             )
+
+             # Small delay to ensure worker background threads have entered the NCCL collective
+             # This prevents a race where the trainer starts NCCL before workers are ready
+
+             time.sleep(0.2)
+
+             self._comm_group = True  # Mark as initialized
+             torchrl_logger.debug(
+                 "vLLM workers should now be blocked in NCCL collective, ready for trainer"
+             )
+
+     def send_weights(self, model_id: str, weights: Any) -> None:
+         """Broadcast weights to all workers using NCCL.
+
+         This method follows AsyncVLLM's periodic-mono pattern:
+         For each weight: RPC → NCCL broadcast → Wait for RPC completion
+
+         This should only be called from rank 0 (trainer).
+
+         Args:
+             model_id: ID of the model (used for logging).
+             weights: TensorDict or dict of weights to broadcast.
+         """
+         # This code is a duplicate from AsyncVLLM
+         # We are waiting for vLLM server to accept tokens endpoints, at which point we will be
+         # able to remove all dependencies on Ray for vllm distributed features.
+         # This will allow a more natural integration with the sender/receiver API.
+
+         import ray
+
+         if self.rank != 0:
+             raise RuntimeError("send_weights should only be called from rank 0")
+
+         if self._comm_group is None:
+             raise RuntimeError(
+                 "Communication group not initialized. Call init_all_workers_group first."
+             )
+
+         if self._model_metadata is None:
+             raise RuntimeError("Model metadata not set")
+
+         if self.vllm_engine is None:
+             raise RuntimeError(
+                 "vllm_engine must be provided to sender for RPC coordination"
+             )
+
+         # Set CUDA device for this operation
+         torch.cuda.set_device(self.device)
+
+         # Convert to dict if needed
+         if isinstance(weights, TensorDictBase):
+             weights_dict = weights.to_dict()
+         else:
+             weights_dict = weights
+
+         torchrl_logger.debug(
+             f"Broadcasting {len(weights_dict)} weights for model '{model_id}'"
+         )
+
+         # Broadcast each weight using periodic-mono pattern (like AsyncVLLM)
+         for name, (dtype, shape) in self._model_metadata.items():
+             if name not in weights_dict:
+                 raise ValueError(
+                     f"Weight '{name}' not found in weights. Weights keys: {list(weights_dict.keys())[:10]}..."
+                 )
+
+             tensor = weights_dict[name].to(f"cuda:{self.device}")
+             dtype_name = str(dtype).split(".")[-1]  # "torch.float16" -> "float16"
+
+             # Step 1: Send RPC to workers for this weight
+             futures = self.vllm_engine.collective_rpc(
+                 "update_weight", args=(name, dtype_name, tuple(shape))
+             )
+
+             # Step 2: Immediately broadcast this weight
+             self._comm_group.broadcast(
+                 tensor,
+                 src=0,
+                 stream=torch.cuda.current_stream(),
+             )
+
+             # Step 3: Wait for workers to complete this weight
+             ray.get(futures)
+             del tensor
+
+         torch.cuda.synchronize()
+         torchrl_logger.debug(f"Broadcast complete for model '{model_id}'")
+
+     def receive_weights(
+         self,
+         timeout: float | None = None,
+         *,
+         weights: Any = None,
+         model: Any = None,
+         strategy: Any = None,
+     ) -> Any | None:
+         """Receive weights from broadcaster.
+
+         This should only be called from worker ranks (rank > 0).
+         This method is called by vLLM engine internally through collective operations.
+
+         Args:
+             timeout: Ignored (vLLM handles synchronization internally).
+             weights: Ignored.
+             model: Ignored.
+             strategy: Ignored.
+
+         Returns:
+             None - vLLM handles weight application internally via collectives.
+         """
+         # vLLM handles this through its own collective operations
+         # The weights are received and applied by the engine during broadcast
+         return None
+
+     def check_connection(self) -> bool:
+         """Check if the communication group is initialized."""
+         return self._comm_group is not None
+
+
+ # ============================================================================
+ # vLLM Weight Synchronization Components
+ # ============================================================================
+
+
+ class VLLMWeightSyncScheme(WeightSyncScheme):
+     """Weight synchronization scheme for vLLM engines.
+
+     This scheme uses collective communication (NCCL) to broadcast weights from
+     a trainer to vLLM inference workers with parallelism support.
+
+     Args:
+         master_address: Address of the master node. Defaults to "localhost".
+         master_port: Port of the master node. If None, will auto-assign.
+         gpus_per_replica: Number of GPUs per replica (tp_size × dp_size × pp_size).
+         num_replicas: Number of vLLM engine replicas. Defaults to 1.
+         strategy: Weight extraction strategy ("tensordict" or "state_dict").
+         device: Device index to use for communication. Defaults to 0.
+             Note: When using Ray, each actor sees only its assigned GPU as device 0
+             due to CUDA_VISIBLE_DEVICES isolation. You should typically use 0.
+
+     .. warning::
+         Collective communication requires ALL ranks to participate simultaneously.
+         Both the sender (trainer, rank 0) and all receivers (vLLM workers, ranks 1+)
+         must call ``init_all_workers_group()`` at approximately the same time for the collective
+         handshake to succeed. Do NOT wait for one init to complete before starting
+         the other - start both and wait for both together.
+
+     Note:
+         The world_size for NCCL will be: 1 (trainer) + num_replicas × gpus_per_replica (vLLM workers)
+
+     Example:
+         >>> # Single replica with 2 GPUs (e.g., tp_size=2)
+         >>> scheme = VLLMWeightSyncScheme(
+         ...     master_port=12345,
+         ...     gpus_per_replica=2,
+         ...     num_replicas=1,
+         ...     strategy="tensordict"
+         ... )  # world_size = 1 + 1*2 = 3
+         >>>
+         >>> # Multiple replicas with 1 GPU each
+         >>> scheme = VLLMWeightSyncScheme(
+         ...     master_port=12345,
+         ...     gpus_per_replica=1,
+         ...     num_replicas=2,
+         ...     strategy="tensordict"
+         ... )  # world_size = 1 + 2*1 = 3
+         >>>
+         >>> # Multiple replicas with tp_size=2, dp_size=1, pp_size=1
+         >>> scheme = VLLMWeightSyncScheme(
+         ...     master_port=12345,
+         ...     gpus_per_replica=2,  # 2*1*1
+         ...     num_replicas=3,
+         ...     strategy="tensordict"
+         ... )  # world_size = 1 + 3*2 = 7
+         >>>
+         >>> # In trainer process (rank 0)
+         >>> sender = VLLMWeightSender(scheme)
+         >>> sender.register_model(policy)
+         >>>
+         >>> # In vLLM worker process (rank 1+)
+         >>> receiver = VLLMWeightReceiver(scheme, vllm_engine)
+         >>>
+         >>> # IMPORTANT: Both must init simultaneously for collective handshake
+         >>> # With Ray:
+         >>> init_sender = sender_actor.init_all_workers_group.remote(metadata)
+         >>> init_receiver = receiver_actor.init_all_workers_group.remote(metadata)
+         >>> ray.get([init_sender, init_receiver])  # Wait for both together
+         >>>
+         >>> # After init, updates work normally
+         >>> sender.update_weights()
+         >>> # Weights are received automatically via collectives
+     """
+
+     def __init__(
+         self,
+         master_address: str | None = None,
+         master_port: int | None = None,
+         gpus_per_replica: int = 1,
+         num_replicas: int = 1,
+         strategy: Literal["tensordict", "state_dict"] = "tensordict",
+         device: torch.device | str | int = 0,
+     ):
+         self.master_address = (
+             master_address if master_address is not None else "localhost"
+         )
+         self.master_port = master_port
+         self.gpus_per_replica = gpus_per_replica
+         self.num_replicas = num_replicas
+         self.strategy_name = strategy
+         # Ray sets CUDA_VISIBLE_DEVICES for each actor, so device 0 is typical
+         self.device = device
+
+         # Auto-assign port if not provided
+         if self.master_port is None:
+             try:
+                 from vllm.utils import get_open_port
+
+                 self.master_port = get_open_port()
+             except ImportError:
+                 # Fallback if vLLM not available
+                 import socket
+
+                 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                     s.bind(("", 0))
+                     self.master_port = s.getsockname()[1]
+
+     def create_transport(self, **kwargs) -> VLLMCollectiveTransport:
+         """Create transport for collective communication.
+
+         For vLLM, this creates a transport but requires additional setup via init_all_workers_group().
+         This method is required by the base class but transport creation for vLLM
+         is more complex and typically handled by sender/receiver initialization.
+
+         Args:
+             **kwargs: Not used for vLLM (kept for API compatibility).
+
+         Returns:
+             A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called).
+         """
+         # Return a transport with default rank 0 (trainer)
+         # Actual initialization happens in sender/receiver
+         world_size = 1 + self.num_replicas * self.gpus_per_replica
+         return VLLMCollectiveTransport(
+             master_address=self.master_address,
+             master_port=self.master_port,
+             rank=0,
+             world_size=world_size,
+             device=self.device,
+         )
+
+     def create_sender(self) -> VLLMWeightSender:
+         """Create a weight sender for the trainer process."""
+         return VLLMWeightSender(self)
+
+     def create_receiver(self, vllm_engine) -> VLLMWeightReceiver:
+         """Create a weight receiver for a vLLM worker process.
+
+         Args:
+             vllm_engine: The vLLM engine instance (must implement RLvLLMEngine interface).
+         """
+         return VLLMWeightReceiver(self, vllm_engine)
+
+
+ class VLLMWeightSender:
+     """Sends weights to vLLM workers using collective communication.
+
+     **RPC + Collective Implementation**
+
+     This class implements both layers:
+
+     1. **RPC Layer**: Currently uses Ray remote calls (implicit in test setup)
+        - Can be extended to other RPC backends (torch.distributed.rpc, gRPC)
+        - In the test, Ray actors provide the RPC mechanism
+
+     2. **Collective Layer**: Uses VLLMCollectiveTransport for NCCL broadcast
+        - Broadcasts weights from trainer (rank 0) to workers (ranks 1+)
+        - High-bandwidth GPU-to-GPU transfer
+
+     **Extending RPC Backends**
+
+     To use a different RPC backend, subclass and override coordination:
+
+     .. code-block:: python
+
+         class TorchRPCVLLMSender(VLLMWeightSender):
+             def update_weights(self, weights=None):
+                 # Custom RPC: Signal workers to prepare
+                 for worker in self.workers:
+                     torch.distributed.rpc.rpc_async(worker, "prepare_receive")
+
+                 # Then do collective (unchanged)
+                 super().update_weights(weights)
+     """
+
+     def __init__(self, scheme: VLLMWeightSyncScheme):
+         self._scheme = scheme
+         self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
+         self._model_ref = None
+         self._transport = None
+         self._model_metadata = None
+
+     def register_model(self, model: Any) -> None:
+         """Register the model to extract weights from."""
+         import weakref
+
+         self._model_ref = weakref.ref(model)
+
+     def init_all_workers_group(
+         self,
+         model_metadata: dict[str, tuple[torch.dtype, torch.Size]],
+         vllm_engine: Any | None = None,
+     ):
+         """Initialize the collective communication group.
+
+         Args:
+             model_metadata: Dict mapping param names to (dtype, shape) tuples.
+             vllm_engine: Optional vLLM engine for RPC coordination. Required for NCCL broadcasts.
+         """
+         self._model_metadata = model_metadata
+         self._vllm_engine = vllm_engine
+
+         # Create transport for trainer (rank 0)
+         world_size = 1 + self._scheme.num_replicas * self._scheme.gpus_per_replica
+         self._transport = VLLMCollectiveTransport(
+             master_address=self._scheme.master_address,
+             master_port=self._scheme.master_port,
+             rank=0,  # Trainer is always rank 0
+             world_size=world_size,
+             device=self._scheme.device,
+             vllm_engine=vllm_engine,
+         )
+         torchrl_logger.debug(
+             f"Initializing transport from sender with world_size={world_size}"
+         )
+         self._transport.init_all_workers_group(model_metadata)
+
+     def update_weights(self, weights: Any | None = None) -> None:
+         """Extract and broadcast weights to vLLM workers.
+
+         Args:
+             weights: Optional weights to send. If None, extracts from registered model.
+         """
+         if self._transport is None:
+             raise RuntimeError(
+                 "Transport not initialized. Call init_all_workers_group first."
+             )
+
+         # Extract weights if not provided
+         if weights is None:
+             model = self._model_ref()
+             if model is None:
+                 raise RuntimeError("Model reference is dead")
+             weights = self._strategy.extract_weights(model)
+         else:
+             # Ensure weights are in the right format
+             if hasattr(weights, "state_dict"):
+                 # It's a module, extract
+                 weights = self._strategy.extract_weights(weights)
+
+         # Send via transport
+         self._transport.send_weights("vllm_model", weights)
+
+
+ class VLLMWeightReceiver:
+     """Receives weights in a vLLM worker using collective communication.
+
+     **RPC + Collective Implementation**
+
+     This class implements both layers:
+
+     1. **RPC Layer**: Currently uses Ray for coordination
+        - `init()` in test uses Ray `ray.get_actor()` to find trainer
+        - Fetches metadata via Ray remote call
+        - Signals readiness to participate in collective
+
+     2. **Collective Layer**: Participates in NCCL broadcast
+        - Receives weights via collective operations
+        - vLLM engine applies weights internally during broadcast
+
+     **Extending RPC Backends**
+
+     To use a different RPC backend:
+
+     .. code-block:: python
+
+         class TorchRPCVLLMReceiver(VLLMWeightReceiver):
+             def init(self):
+                 # Custom RPC: Get metadata from trainer
+                 metadata = torch.distributed.rpc.rpc_sync(
+                     "trainer",
+                     lambda: get_metadata()
+                 )
+
+                 # Then init collective (unchanged)
+                 self.receiver.init_all_workers_group(metadata)
+
+     Note:
+         The RPC and collective layers are loosely coupled. The RPC layer
+         ensures all ranks are ready before the collective starts, but the
+         actual data transfer is independent of the RPC mechanism.
+     """
+
+     def __init__(self, scheme: VLLMWeightSyncScheme, vllm_engine):
+         self._scheme = scheme
+         self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
+         self._vllm_engine = vllm_engine
+         self._transport = None
+
+     def init_all_workers_group(
+         self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]]
+     ):
+         """Initialize the collective communication group.
+
+         Args:
+             model_metadata: Dict mapping param names to (dtype, shape) tuples.
+         """
+         # For vLLM receiver, we use rank=1 as a placeholder
+         # The engine handles actual rank assignment internally for all workers
+         world_size = 1 + self._scheme.num_replicas * self._scheme.gpus_per_replica
+         self._transport = VLLMCollectiveTransport(
+             master_address=self._scheme.master_address,
+             master_port=self._scheme.master_port,
+             rank=None,  # Placeholder - engine assigns actual ranks
+             world_size=world_size,
+             device=self._scheme.device,
+             vllm_engine=self._vllm_engine,
+         )
+         torchrl_logger.debug(
+             f"Initializing transport from receiver with world_size={world_size}."
+         )
+         self._transport.init_all_workers_group(model_metadata)
+
+     def apply_weights(self, weights: Any, inplace: bool = True) -> None:
+         """Apply weights to vLLM engine.
+
+         Args:
+             weights: The weights to apply.
+             inplace: Whether to apply weights in place. Default is `True`.
+
+         Note: For vLLM, weights are applied automatically during the collective
+         broadcast operation. This method is a no-op but kept for API consistency.
+         """
+         # vLLM handles weight application through its collective operations
+         # The weights are already applied by the time broadcast completes
+
+     def poll_and_apply(self, timeout: float = 0.1) -> bool:
+         """Poll for and apply weights.
+
+         Returns:
+             False - vLLM uses push-based updates via collectives, not polling.
+         """
+         # vLLM uses collective broadcasts (push), not polling
+         # This is handled by the engine's collective operations
+         return False
+
+
+ # ============================================================================
+ # Helper Functions
+ # ============================================================================
+
+
+ def get_model_metadata(model) -> dict[str, tuple[torch.dtype, torch.Size]]:
+     """Extract model metadata from a model.
+
+     Args:
+         model: A model with state_dict() or a model wrapper.
+
+     Returns:
+         Dict mapping parameter names to (dtype, shape) tuples.
+
+     Note:
+         This function must extract keys in the same format as WeightStrategy.extract_weights()
+         to ensure consistency between metadata and actual weight keys during broadcasting.
+     """
+     # Extract state_dict directly from the model
+     # This ensures keys match what extract_weights() will produce
+     if hasattr(model, "state_dict"):
+         if hasattr(model, "merge_and_unload"):
+             # LoRA model
+             sd = model.merge_and_unload().state_dict()
+         else:
+             sd = model.state_dict()
+     else:
+         raise TypeError(f"Cannot extract state_dict from {type(model)}")
+
+     return {k: (v.dtype, v.shape) for k, v in sd.items()}
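
A minimal usage sketch of the API defined in this file, based only on the docstrings above. It assumes this diff corresponds to torchrl/weight_update/llm/vllm_nccl.py from the file list, and it uses placeholder objects (policy, vllm_engine, and the port number) that are not part of the packaged code:

    # Sketch only: "policy" is the training nn.Module, "vllm_engine" an
    # RLvLLMEngine-compatible engine; both are placeholders.
    from torchrl.weight_update.llm.vllm_nccl import (
        VLLMWeightSyncScheme,
        get_model_metadata,
    )

    scheme = VLLMWeightSyncScheme(
        master_port=12345,       # or None to auto-assign a free port
        gpus_per_replica=1,      # tp_size * dp_size * pp_size
        num_replicas=1,          # NCCL world_size = 1 + 1 * 1 = 2
        strategy="tensordict",
    )

    # Trainer process (rank 0): the sender extracts weights from the policy.
    sender = scheme.create_sender()
    sender.register_model(policy)

    # vLLM-side process (ranks 1+): the receiver coordinates through the engine.
    receiver = scheme.create_receiver(vllm_engine)

    # Shared (dtype, shape) metadata drives the per-weight broadcast loop.
    metadata = get_model_metadata(policy)

    # Both sides must enter the handshake at roughly the same time, e.g. as
    # two Ray remote calls gathered with a single ray.get(); do not wait for
    # one init to finish before starting the other.
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
    receiver.init_all_workers_group(metadata)

    # Each later update is one RPC plus one NCCL broadcast per parameter.
    sender.update_weights()

Because send_weights() iterates over the stored metadata and looks each key up in the extracted weights, get_model_metadata() reads its keys straight from state_dict(); this keeps the metadata keys and the broadcast weight keys in the same format.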