torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,275 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import copy
8
+
9
+ import warnings
10
+ from collections.abc import Callable, Iterator
11
+ from typing import Any
12
+
13
+ import torch
14
+ from tensordict import TensorDictBase
15
+ from tensordict.nn import TensorDictModuleBase
16
+ from torchrl.collectors.llm import LLMCollector
17
+ from torchrl.collectors.weight_update import WeightUpdaterBase
18
+ from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
19
+ from torchrl.envs import EnvBase
20
+ from torchrl.envs.llm.transforms.policy_version import PolicyVersion
21
+
22
+ RAY_ERR = None
23
+ try:
24
+ import ray
25
+
26
+ _has_ray = True
27
+ except ImportError as err:
28
+ _has_ray = False
29
+ RAY_ERR = err
30
+
31
+
32
class RayLLMCollector(LLMCollector):
    """A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.

    Args:
        env (EnvBase or EnvBase constructor): the environment to be used for data collection.

    Keyword Args:
        policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the `policy` argument.
        dialog_turns_per_batch (int): A keyword-only argument representing the total
            number of elements in a batch.
        total_dialog_turns (int): A keyword-only argument representing the total
            number of dialog turns returned by the collector during its lifespan.
        yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
            last (done) steps.
        yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
            or single, completed trajectories.
        postproc (Callable, optional): A post-processing transform.
        async_envs (bool, optional): if True, the environment will be run asynchronously.
        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead.
        reset_at_each_iter (bool, optional): if True, the environment will be reset at each iteration.
        flatten_data (bool, optional): if True, the collector will flatten the collected data
            before returning it.
        weight_updater (WeightUpdaterBase or constructor, optional): An instance of WeightUpdaterBase
            or its subclass, responsible for updating the policy weights on remote inference workers.
        ray_init_config (dict[str, Any], optional): keyword arguments to pass to ray.init().
        remote_config (dict[str, Any], optional): keyword arguments to pass to cls.as_remote().
        num_cpus (int, optional): Number of CPUs of the actor. Defaults to `None` (taken from remote_config).
        num_gpus (int, optional): Number of GPUs of the actor. Defaults to `None` (taken from remote_config).
        sync_iter (bool, optional): if `True`, items yielded by the collector will be synced to the local process.
            If `False`, the collector will collect the next batch of data in between yielding.
            This has no effect when data is collected through the :meth:`start` method.
            For example:

            >>> collector = RayLLMCollector(..., sync_iter=True)
            >>> for data in collector:  # blocking
            ...     # expensive operation - collector is idle
            >>> collector = RayLLMCollector(..., sync_iter=False)
            >>> for data in collector:  # non-blocking
            ...     # expensive operation - collector is collecting data

            This is somewhat equivalent to using :class:`~torchrl.collectors.MultiSyncCollector` (`sync_iter=True`) or
            :class:`~torchrl.collectors.MultiAsyncDataCollector` (`sync_iter=False`).
            Defaults to `True`.
        verbose (bool, optional): if ``True``, the collector will print progress information.
            Defaults to `False`.
    """

    def __init__(
        self,
        env: EnvBase | Callable[[], EnvBase],
        *,
        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
        policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
        | None = None,
        dialog_turns_per_batch: int,
        total_dialog_turns: int = -1,
        yield_only_last_steps: bool | None = None,
        yield_completed_trajectories: bool | None = None,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        async_envs: bool | None = None,
        replay_buffer: ReplayBuffer | None = None,
        reset_at_each_iter: bool = False,
        flatten_data: bool | None = None,
        weight_updater: WeightUpdaterBase
        | Callable[[], WeightUpdaterBase]
        | None = None,
        ray_init_config: dict[str, Any] | None = None,
        remote_config: dict[str, Any] | None = None,
        track_policy_version: bool | PolicyVersion = False,
        sync_iter: bool = True,
        verbose: bool = False,
        num_cpus: int | None = None,
        num_gpus: int | None = None,
    ) -> None:
        if not _has_ray:
            raise RuntimeError(
                "ray library not found, unable to create a RayLLMCollector. "
            ) from RAY_ERR
        if not ray.is_initialized():
            if ray_init_config is None:
                from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG

                ray_init_config = DEFAULT_RAY_INIT_CONFIG
            ray.init(**ray_init_config)
        # Only materialize/copy the remote config when we actually need to add
        # keys to it. Copying avoids mutating the caller's dict, and creating
        # an empty dict fixes an AttributeError previously raised when
        # ``remote_config=None`` was combined with ``sync_iter=False`` or with
        # ``num_cpus``/``num_gpus``.
        if not sync_iter or num_cpus is not None or num_gpus is not None:
            remote_config = (
                copy.copy(remote_config) if remote_config is not None else {}
            )
            if not sync_iter:
                # A second concurrency slot lets the remote collector produce
                # the next batch while the previous one is consumed locally.
                remote_config.setdefault("max_concurrency", 2)
            if num_cpus is not None:
                remote_config.setdefault("num_cpus", num_cpus)
            if num_gpus is not None:
                remote_config.setdefault("num_gpus", num_gpus)
        remote_cls = LLMCollector.as_remote(remote_config).remote
        self.sync_iter = sync_iter
        self._collector = remote_cls(
            env=env,
            policy=policy,
            policy_factory=policy_factory,
            dialog_turns_per_batch=dialog_turns_per_batch,
            total_dialog_turns=total_dialog_turns,
            yield_only_last_steps=yield_only_last_steps,
            yield_completed_trajectories=yield_completed_trajectories,
            postproc=postproc,
            async_envs=async_envs,
            replay_buffer=replay_buffer,
            reset_at_each_iter=reset_at_each_iter,
            flatten_data=flatten_data,
            weight_updater=weight_updater,
            track_policy_version=track_policy_version,
            verbose=verbose,
        )

    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
        """Sets the post-processing transform on the remote collector."""
        return ray.get(self._collector.set_postproc.remote(postproc))

    def _next_remote(self) -> None:
        # Launch the next collection on the remote actor without blocking;
        # returns an ObjectRef to be resolved by the caller.
        return self._collector.next.remote()

    def next(self) -> None:
        """Get the next batch of data from the collector.

        Returns:
            None as the data is written directly to the replay buffer.
        """
        return ray.get(self._next_remote())

    def __iter__(self) -> Iterator[None]:
        """Returns an iterator that yields None as the collector writes directly to the replay buffer."""
        if not self.sync_iter:
            # Pre-launch the first collection so it overlaps with consumption.
            future = self._next_remote()
        else:
            future = None
        while True:
            try:
                if self.sync_iter:
                    yield self.next()
                else:
                    # Resolve the in-flight batch, then immediately launch the
                    # next one before yielding (pipelined iteration).
                    result = ray.get(future)
                    future = self._next_remote()
                    yield result
            except StopIteration:
                break

    def start(self):
        """Starts the collector in a background thread."""
        pending_task = self._collector.start.remote()
        return ray.get(pending_task)

    def is_running(self):
        """Returns whether the remote collector is currently running."""
        return ray.get(self._collector.is_running.remote())

    def shutdown(self):
        """Shuts down the collector."""
        pending_task = self._collector.shutdown.remote()
        return ray.get(pending_task)

    def async_shutdown(self, timeout=None):
        """Shuts down the collector asynchronously."""
        pending_task = self._collector.async_shutdown.remote(timeout=timeout)
        return ray.get(pending_task)

    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
        **kwargs,
    ):
        """Updates the policy weights on remote workers.

        Args:
            policy_or_weights: The weights to update with. Can be:
                - TensorDictModuleBase: A policy module whose weights will be extracted
                - TensorDictBase: A TensorDict containing weights
                - dict: A regular dict containing weights
                - None: Will try to get weights from server using _get_server_weights()
            worker_ids: The workers to update. If None, updates all workers.
        """
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        pending_task = self._collector.update_policy_weights_.remote(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids
        )
        return ray.get(pending_task)

    @property
    def total_dialog_turns(self):
        """Total number of dialog turns to collect."""
        # NOTE(review): `.remote` is not called here (no parentheses) —
        # presumably `as_remote` exposes properties this way; confirm.
        return ray.get(self._collector.total_dialog_turns.remote)

    @property
    def dialog_turns_per_batch(self) -> int:
        """Number of dialog turns per batch."""
        return ray.get(self._collector.dialog_turns_per_batch.remote)

    @property
    def rollout(self) -> Callable[[], TensorDictBase]:
        """Returns the rollout function."""
        return ray.get(self._collector.rollout.remote())

    def init_updater(self, *args, **kwargs):
        """Initialize the weight updater with custom arguments.

        This method calls init_updater on the remote collector.

        Args:
            *args: Positional arguments for weight updater initialization
            **kwargs: Keyword arguments for weight updater initialization
        """
        ray.get(self._collector.init_updater.remote(*args, **kwargs))

    @property
    def policy_version(self) -> str | int | None:
        """The current version of the policy.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
        """
        return ray.get(self._collector.get_policy_version.remote())

    @property
    def weight_updater(self) -> WeightUpdaterBase:
        """The weight updater instance.

        We can pass the weight updater because it's stateless, hence serializable.
        """
        return ray.get(self._collector.weight_updater.remote)

    @weight_updater.setter
    def weight_updater(self, weight_updater: WeightUpdaterBase):
        """Set the weight updater instance."""
        ray.get(self._collector.set_weight_updater.remote(weight_updater))
        weight_updater.register_collector(self)

    def increment_version(self):
        """Increment the policy version."""
        return ray.get(self._collector.increment_version.remote())
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+
9
+ from queue import Full as QueueFull, Queue
10
+
11
+ from tensordict import TensorDictBase
12
+
13
+ from torchrl._utils import logger as torchrl_logger
14
+
15
+ _has_ray = importlib.util.find_spec("ray") is not None
16
+
17
+
18
class _QueueAsRB:
    """Duck-typed replay-buffer adapter around a plain or Ray queue.

    Exposes an ``extend`` method so that a ``queue.Queue`` or
    ``ray.util.queue.Queue`` can stand in where a replay buffer is expected
    by the collector machinery.
    """

    def __init__(self, queue: Queue | ray.util.queue.Queue):  # noqa
        # Ray is required because extend() must handle ray's Full exception.
        if not _has_ray:
            raise ImportError("Ray not installed.")
        self.queue = queue

    def extend(self, data: TensorDictBase):
        """Unbind ``data`` along dim 0 and push each element to the queue.

        Implements a drop-oldest policy: when the queue is full, the oldest
        item is evicted to make room and a warning is logged.
        """
        # Imported lazily so this module stays importable without ray.
        from ray.util.queue import Full as RayQueueFull

        # unbind the data and put in the queue
        for item in data.unbind(0):
            while True:
                try:
                    self.queue.put_nowait(item)
                    break
                except (QueueFull, RayQueueFull):
                    # Remove the oldest item to make space.
                    # NOTE(review): get() can block if a consumer drains the
                    # queue between the Full exception and this call — confirm
                    # single-producer/single-consumer usage.
                    self.queue.get()
                    # Use `warning`: `Logger.warn` is a deprecated alias.
                    torchrl_logger.warning("rollout queue full. Discarding data.")
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from .vllm import vLLMUpdater
8
+ from .vllm_v2 import vLLMUpdaterV2
9
+
10
+ __all__ = ["vLLMUpdater", "vLLMUpdaterV2"]
@@ -0,0 +1,348 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+
9
+ import importlib.util
10
+
11
+ import torch
12
+ import torch.cuda
13
+ import torch.distributed
14
+
15
+ from tensordict import TensorDict, TensorDictBase
16
+ from tensordict.nn import TensorDictModuleBase
17
+
18
+ from torchrl._utils import logger as torchrl_logger
19
+
20
+ from torchrl.collectors.weight_update import WeightUpdaterBase
21
+ from torchrl.modules.llm.backends import stateless_init_process_group
22
+
23
+ _has_vllm = importlib.util.find_spec("vllm") is not None
24
+
25
+ _has_ray = importlib.util.find_spec("ray") is not None
26
+
27
+
28
class vLLMUpdaterMeta(abc.ABCMeta):
    """Metaclass for vLLMUpdater that allows switching between V1 and V2 implementations.

    When instantiating vLLMUpdater with v2=True, returns a vLLMUpdaterV2 instance instead.
    This provides a unified entry point for both updater versions while maintaining
    backward compatibility.
    """

    def __call__(cls, *args, v2=False, **kwargs):
        # Default path: build a regular V1 instance.
        if not v2:
            return super().__call__(*args, **kwargs)

        # Import V2 lazily to avoid circular imports.
        from .vllm_v2 import vLLMUpdaterV2

        # V2 has a different signature - it expects a vllm_engine parameter.
        # Positional args or any v1-only keyword means the caller mixed the
        # two APIs; we cannot auto-convert, so fail loudly.
        v1_only_kwargs = (
            "master_address",
            "master_port",
            "model_metadata",
            "vllm_tp_size",
        )
        if args or any(key in kwargs for key in v1_only_kwargs):
            raise TypeError(
                "When using v2=True, you must provide a vllm_engine parameter instead of "
                "the v1 parameters (master_address, master_port, model_metadata, vllm_tp_size). "
                "See vLLMUpdaterV2 documentation for details."
            )

        # Forward to V2 constructor.
        return vLLMUpdaterV2(*args, **kwargs)
64
+
65
+
66
class vLLMUpdater(WeightUpdaterBase, metaclass=vLLMUpdaterMeta):
    """A class that sends weights to vLLM workers.

    This class handles synchronizing weights between a training policy and vLLM inference workers.
    It supports both local vLLM instances and remote Ray actors.

    Args:
        master_address (str, optional): The master address for distributed training. Defaults to localhost.
        master_port (int, optional): The master port for distributed training. If None, will auto-assign.
        model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional): Model metadata mapping
            parameter names to their dtype and shape. If not provided, will be extracted from policy.
        vllm_tp_size (int, optional): vLLM tensor parallel size. Defaults to 1.
        v2 (bool, optional): If True, returns a vLLMUpdaterV2 instance instead. This is an experimental
            feature that provides better integration with AsyncVLLM engines. When using v2=True, you must
            provide a vllm_engine parameter instead of the above parameters. Defaults to False.

    Methods:
        init: Initialize the updater with model metadata and initialize the group.
        _sync_weights_with_worker: Synchronize weights with a vLLM worker.
        _get_server_weights: Not used - weights must be passed directly.
        _maybe_map_weights: Map incoming weights to a TensorDict.
        all_worker_ids: Returns [0] since we only have one worker.

    .. note::
        This class assumes the policy is a transformers model that can be loaded by vLLM.
        The policy must have a state_dict() method that returns the model weights.

    .. warning::
        The v2=True option is experimental and may have backward-compatibility breaking changes
        in future releases. However, it is generally considered a better option for working with
        AsyncVLLM engines and provides improved performance and reliability.
    """

    def __init__(
        self,
        master_address: str | None = None,
        master_port: int | None = None,
        model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None,
        vllm_tp_size: int | None = None,
    ):
        torchrl_logger.debug(f"=> in {type(self).__name__}.__init__")
        self.master_address = master_address
        self.master_port = master_port
        self.model_metadata = model_metadata
        # Tri-state flag: None = group not initialized yet, True = initialized.
        self.initialized_group = None
        if vllm_tp_size is None:
            vllm_tp_size = 1
        self.vllm_tp_size = vllm_tp_size
        # Cached handle to the (possibly remote) policy model.
        self._model_ref = None

    def init(self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) -> None:
        """Initialize the updater with model metadata and initialize the group.

        Args:
            model_metadata (dict[str, tuple[torch.dtype, torch.Size]]): The model metadata mapping
                parameter names to their dtype and shape.
        """
        self.model_metadata = model_metadata
        self.maybe_init_group()

    @property
    def master_address(self):
        """Master address for the weight-sync group (lazily defaults to localhost)."""
        if self._master_address is None:
            self._master_address = "localhost"  # get_ip()
        return self._master_address

    @master_address.setter
    def master_address(self, value):
        self._master_address = value

    @property
    def master_port(self):
        """Master port for the weight-sync group (lazily picks a free port)."""
        if self._master_port is None:
            try:
                from vllm.utils import get_open_port

                self._master_port = get_open_port()
            except ImportError:
                # Fallback if vLLM get_open_port not available: bind to port 0
                # and let the OS choose a free one.
                import socket

                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", 0))
                    self._master_port = s.getsockname()[1]
        return self._master_port

    @master_port.setter
    def master_port(self, value):
        self._master_port = value

    def _get_model_ref(self):
        """Get a reference to the model actor, either directly or through Ray.

        Returns:
            The model reference that can be used for weight updates
        """
        if self._model_ref is not None:
            return self._model_ref

        if hasattr(self.collector, "_collector"):
            # We're dealing with a RayLLMCollector
            import ray

            # Get direct reference to the model actor
            self._model_ref = ray.get(
                self.collector._collector.get_policy_model.remote()
            )
        else:
            # We're dealing with a local collector
            self._model_ref = self.collector.policy.model

        return self._model_ref

    def _init_group(self):
        """Create the weight-sync process group between the trainer (rank 0) and vLLM workers."""
        import ray

        torchrl_logger.debug(f"=> in {type(self).__name__}._init_group")
        # Rank 0 is the trainer; ranks 1..tp_size are the vLLM workers.
        weight_sync_world_size = self.vllm_tp_size + 1
        torchrl_logger.debug(f"initializing group with {weight_sync_world_size=}...")
        torchrl_logger.debug(f"vllm_tp_size={self.vllm_tp_size}")

        model_ref = self._get_model_ref()

        torchrl_logger.debug(f"model_ref: {model_ref}")
        # Initialize the weight update group on the worker side (async).
        torchrl_logger.debug("Calling init_weight_update_group...")
        init_weight_update_group_getter = model_ref.collective_rpc.remote(
            "init_weight_update_group",
            args=(self.master_address, self.master_port, 1, weight_sync_world_size),
        )
        torchrl_logger.debug("init_weight_update_group remote call succeeded")

        # Join the group from the trainer side (rank 0) while the workers join
        # from theirs — both sides must rendezvous before ray.get below returns.
        torchrl_logger.debug("Calling stateless_init_process_group within updater...")
        self.vllm_comm_group = stateless_init_process_group(
            self.master_address,
            self.master_port,
            0,
            weight_sync_world_size,
            torch.device("cuda:0"),
        )

        ray.get(init_weight_update_group_getter)
        torchrl_logger.debug("init_weight_update_group getter succeeded")

        torchrl_logger.debug("group initialized")
        self.initialized_group = True

    def maybe_init_group(self):
        """Initialize the comm group once, as soon as a collector is attached."""
        if self.initialized_group is None and self.collector is not None:
            self._init_group()

    def _sync_weights_with_worker(
        self,
        *,
        worker_id: int | torch.device | None = None,
        server_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    ) -> None:
        """Synchronize weights with a vLLM worker.

        Args:
            worker_id: Not used - we only have one worker.
            server_weights: The weights to sync. Can be:
                - TensorDictModuleBase: A policy module whose weights will be extracted
                - TensorDictBase: A TensorDict containing weights
                - dict: A regular dict containing weights
                - None: Will try to get weights from server using _get_server_weights()
        """
        import ray

        if server_weights is None:
            raise ValueError("server_weights cannot be None for vLLM updater")

        if self.initialized_group is None:
            raise RuntimeError(
                "Failed to update weights because sender is not initialized."
            )
        if self.model_metadata is None:
            raise RuntimeError(
                "Failed to update weights because model metadata is not set. "
                "In async mode, you must set the model metadata in the training actor "
                "before any weight updates."
            )

        model_ref = self._get_model_ref()

        # First tell the workers (asynchronously) which weights are coming so
        # they can post the matching receives.
        torchrl_logger.debug("broadcasting with update_weight_broadcast")
        remotes = []
        for k, (dtype, shape) in self.model_metadata.items():
            remotes.append(
                model_ref.collective_rpc.remote(
                    "update_weight_broadcast", args=(k, dtype, shape)
                )
            )

        # Then broadcast the actual tensors from rank 0 over the comm group.
        torchrl_logger.debug("broadcasting...")
        for k in self.model_metadata:
            val = server_weights[k].to(torch.device("cuda:0"))
            self.vllm_comm_group.broadcast(
                val,
                src=0,
                stream=torch.cuda.current_stream(),
            )
            # Release the GPU copy eagerly between large broadcasts.
            del val

        # Wait for all workers to acknowledge the updates before returning.
        ray.get(remotes)
        torchrl_logger.debug("done broadcasting")
        torch.cuda.synchronize()

    def _get_server_weights(self) -> TensorDictBase | None:
        """Not used - weights must be passed directly via policy."""
        return None

    def _maybe_map_weights(
        self, server_weights: TensorDictBase | TensorDictModuleBase | dict
    ) -> TensorDictBase:
        """Map weights from any format to the format expected by vLLM.

        Args:
            server_weights: The weights to map. Can be:
                - TensorDictModuleBase: A policy module whose weights will be extracted
                - TensorDictBase: A TensorDict containing weights
                - dict: A regular dict containing weights

        Returns:
            TensorDictBase: The mapped weights in TensorDict format
        """
        if isinstance(server_weights, TensorDictModuleBase):
            # Extract weights from policy module using merge_and_unload for LLMs
            if not hasattr(server_weights, "model"):
                raise ValueError("TensorDictModuleBase must have a 'model' attribute")
            # Check if it's a LoRA model
            if hasattr(server_weights.model, "merge_and_unload"):
                state_dict = server_weights.model.merge_and_unload().state_dict()
            else:
                state_dict = server_weights.model.state_dict()
            return TensorDict(state_dict, [])
        elif isinstance(server_weights, TensorDictBase):
            return server_weights
        elif isinstance(server_weights, dict):
            return TensorDict(server_weights, [])
        else:
            raise TypeError(
                f"server_weights must be TensorDictModuleBase, TensorDictBase or dict, got {type(server_weights)}"
            )

    @classmethod
    def get_model_metadata(
        cls, model: TensorDictModuleBase
    ) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Get the model metadata from a model.

        Args:
            model (TensorDictModuleBase): The model to get the metadata from.
                Must be a TransformersWrapper or equivalent.

        Returns:
            dict[str, tuple[torch.dtype, torch.Size]]: The model metadata.
        """
        # Check if the model has a LoRA adapter
        if hasattr(model.model, "merge_and_unload"):
            sd = model.model.merge_and_unload().state_dict()
        else:
            sd = model.model.state_dict()
        model_metadata = {k: (v.dtype, v.shape) for k, v in sd.items()}
        return model_metadata

    def all_worker_ids(self) -> list[int]:
        """Returns [0] since we only have one worker."""
        return [0]

    def register_collector(self, collector: BaseCollector):  # noqa: F821
        """Register the collector and bump the policy version after each update."""
        result = super().register_collector(collector)
        self.register_post_hook(collector.increment_version)
        return result