torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,2075 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Async vLLM engine implementation for efficient batching and inference.
7
+
8
+ This module provides an async vLLM engine that leverages native vLLM batching
9
+ for better performance and memory efficiency compared to the explicit batching
10
+ approach used in the legacy vLLM backend.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ import os
17
+ import random
18
+ import time
19
+ import uuid
20
+ from collections.abc import Iterator, Sequence
21
+ from concurrent.futures import ThreadPoolExecutor, wait
22
+ from typing import Any, Literal, TYPE_CHECKING
23
+
24
+ import torch
25
+
26
+ from torchrl._utils import logger as torchrl_logger
27
+
28
+ # Import RLvLLMEngine and shared utilities
29
+ from .base import RLvLLMEngine
30
+ from .vllm_utils import stateless_init_process_group
31
+
32
+
33
+ _has_vllm = True
34
+
35
+ if TYPE_CHECKING:
36
+ from vllm.engine.async_llm_engine import AsyncEngineArgs
37
+ from vllm.engine.request import RequestOutput
38
+ from vllm.engine.sampling_params import SamplingParams
39
+
40
+ TIMEOUT_SECONDS = int(os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", "300"))  # env values are strings
41
+
42
+ try:
43
+ import vllm
44
+
45
+ _has_vllm = True
46
+ except ImportError:
47
+ vllm = None
48
+ _has_vllm = False
49
+
50
+
51
+ def _get_ray():
52
+ """Import Ray on demand to avoid global import side-effects.
53
+
54
+ Returns:
55
+ ModuleType: The imported Ray module.
56
+
57
+ Raises:
58
+ ImportError: If Ray is not installed.
59
+ """
60
+ try:
61
+ import ray # type: ignore
62
+
63
+ return ray
64
+ except Exception as e: # pragma: no cover - surfaced to callers
65
+ raise ImportError(
66
+ "ray is not installed. Please install it with `pip install ray`."
67
+ ) from e
68
+
69
+
70
+ class _AsyncvLLMWorker:
71
+ """Async vLLM worker extension for Ray with weight update capabilities."""
72
+
73
+ def init_weight_update_group(
74
+ self,
75
+ master_address: str,
76
+ master_port: str,
77
+ rank_offset: int,
78
+ world_size: int,
79
+ ):
80
+ """Initialize weight update group for this worker (non-blocking).
81
+
82
+ This method starts NCCL initialization in a background thread and returns immediately,
83
+ allowing the RPC to complete. The NCCL collective will complete when the trainer joins.
84
+
85
+ Args:
86
+ master_address (str): The master address for distributed training.
87
+ master_port (str): The master port for distributed training.
88
+ rank_offset (int): Rank offset for this worker in the global weight update group.
89
+ world_size (int): Total number of processes in the weight update group.
90
+ """
91
+ import threading
92
+
93
+ from vllm.distributed.parallel_state import get_world_group
94
+
95
+ torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
96
+ if getattr(self, "model_update_group", None) is not None:
97
+ torchrl_logger.info("Model update group already initialized")
98
+ return
99
+
100
+ # Get the local rank within the tensor parallel group
101
+ tp_group = get_world_group()
102
+ local_rank = tp_group.rank
103
+ torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")
104
+
105
+ # Calculate the global rank for weight update group
106
+ rank = local_rank + rank_offset
107
+ torchrl_logger.info(
108
+ f"Starting {type(self).__name__} weight update group init (non-blocking) with "
109
+ f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}"
110
+ )
111
+
112
+ # Start NCCL init in a background thread so this RPC can return immediately
113
+ def _init_nccl_background():
114
+ try:
115
+ from .vllm_utils import stateless_init_process_group
116
+
117
+ torchrl_logger.info(
118
+ f"Worker rank {rank}: Starting NCCL init (will block until collective completes)..."
119
+ )
120
+ self.model_update_group = stateless_init_process_group(
121
+ master_address, master_port, rank, world_size, self.device
122
+ )
123
+ torchrl_logger.info(f"Worker rank {rank}: NCCL init complete!")
124
+ except Exception as e:
125
+ torchrl_logger.error(f"Worker rank {rank}: NCCL init failed: {e}")
126
+ raise
127
+
128
+ thread = threading.Thread(target=_init_nccl_background, daemon=False)
129
+ thread.start()
130
+
131
+ # Store thread reference for potential cleanup
132
+ self._nccl_init_thread = thread
133
+
134
+ torchrl_logger.info(
135
+ f"{type(self).__name__}.init_weight_update_group dispatched (non-blocking)"
136
+ )
137
+
138
+ def update_weight(self, name: str, dtype_name: str, shape: tuple[int, ...]):
139
+ """Update weight via broadcast from master (rank 0) - periodic-mono pattern.
140
+
141
+ Args:
142
+ name (str): Parameter name.
143
+ dtype_name (str): Parameter dtype name (e.g., 'bfloat16').
144
+ shape (tuple[int, ...]): Parameter shape.
145
+ """
146
+ if self.model_update_group is None:
147
+ raise RuntimeError("Weight update group not initialized")
148
+
149
+ # Convert dtype name to dtype (like periodic-mono)
150
+ dtype = getattr(torch, dtype_name)
151
+
152
+ # Workers receive broadcast from master (rank 0)
153
+ weight = torch.empty(shape, dtype=dtype, device="cuda")
154
+ self.model_update_group.broadcast(
155
+ weight, src=0, stream=torch.cuda.current_stream()
156
+ )
157
+ self.model_runner.model.load_weights(weights=[(name, weight)])
158
+ del weight
159
+
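# Editorial sketch (not part of this file): a trainer-side counterpart to
# `update_weight` above. Rank 0 joins the same stateless process group and
# broadcasts each parameter by name/dtype/shape, which the workers receive in
# `update_weight`. Assumptions: `model` exposes `named_parameters()`, the same
# `stateless_init_process_group` helper was used with rank=0, and
# `collective_rpc` is some callable that invokes "update_weight" on every vLLM
# worker; the updater actually shipped with TorchRL may differ.
import torch

def push_weights(model, group, collective_rpc) -> None:
    for name, param in model.named_parameters():
        dtype_name = str(param.dtype).split(".")[-1]  # e.g. "bfloat16"
        # Tell the workers what to expect, then broadcast from rank 0.
        collective_rpc("update_weight", args=(name, dtype_name, tuple(param.shape)))
        group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())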
160
+ def check_nccl_group_ready(self):
161
+ """Check if NCCL group is ready for communication."""
162
+ ready = self.model_update_group is not None
163
+ torchrl_logger.info(f"Worker NCCL group ready: {ready}")
164
+ return ready
165
+
166
+ def load_weights_from_storage(self, storage_path: str, num_threads: int = 1):
167
+ """Load weights from shared storage (double-buffer approach).
168
+
169
+ This method reads weights from a memory-mapped TensorDict directory
170
+ and loads them into the model. Used for file-based weight synchronization
171
+ as an alternative to NCCL collectives.
172
+
173
+ Args:
174
+ storage_path: Path to the directory containing memory-mapped weights
175
+ num_threads: Number of threads for reading (default: 1)
176
+ """
177
+ from tensordict import TensorDict
178
+
179
+ torchrl_logger.info(f"Worker loading weights from {storage_path}")
180
+
181
+ # Read weights from shared storage
182
+ weights = TensorDict.load_memmap(storage_path)
183
+ weights = weights.flatten_keys(".")
184
+
185
+ # Convert to list of (name, tensor) tuples
186
+ weights_list = list(weights.items())
187
+
188
+ torchrl_logger.info(f"Worker loading {len(weights_list)} weights into model")
189
+
190
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
191
+ futures = [
192
+ executor.submit(self.model_runner.model.load_weights, weights)
193
+ for weights in weights_list
194
+ ]
195
+ wait(futures)
196
+
197
+ torchrl_logger.info(
198
+ f"Worker successfully loaded {len(weights_list)} weights from storage"
199
+ )
200
+
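# Editorial sketch (not part of this file): the writer side of the
# double-buffer scheme consumed by `load_weights_from_storage`. Assumptions:
# `storage_path` is visible to the workers (shared filesystem) and the nested
# key layout produced by `TensorDict.from_module` flattens back to the names
# expected by `model_runner.model.load_weights`; the writer actually shipped
# in torchrl/weight_update/llm/vllm_double_buffer.py may differ.
from tensordict import TensorDict

def dump_weights(model, storage_path: str) -> None:
    # Detached CPU copy of the parameters, stored as one memmap file per tensor.
    td = TensorDict.from_module(model).data.cpu()
    td.memmap(prefix=storage_path)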
201
+
202
+ class _AsyncLLMEngine:
203
+ """Extended AsyncLLMEngine with TorchRL-specific features.
204
+
205
+ This class wraps vLLM's AsyncLLMEngine and adds functionality needed
206
+ for TorchRL integration, including weight updates and batch management.
207
+
208
+ This is a private class and should not be used directly. Use the ray remote actor class :class:`AsyncLLMEngineActor` instead.
209
+
210
+ Keyword Args:
211
+ engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
212
+ bundle_indices (list[int], optional): Bundle indices for the engine.
213
+ enable_prefix_caching (bool, optional): Whether to enable prefix caching.
214
+
215
+ .. warning::
216
+ enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
217
+ Set it to True if prompt log probs are not needed.
218
+ See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
219
+ """
220
+
221
+ def __init__(
222
+ self,
223
+ *,
224
+ engine_args: AsyncEngineArgs,
225
+ bundle_indices: list[int] | None = None,
226
+ enable_prefix_caching: bool = False,
227
+ ):
228
+ if not _has_vllm:
229
+ raise ImportError(
230
+ "vllm is not installed. Please install it with `pip install vllm`."
231
+ )
232
+
233
+ from vllm import AsyncLLMEngine
234
+
235
+ if bundle_indices is not None:
236
+ os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
237
+
238
+ engine_args.enable_prefix_caching = enable_prefix_caching
239
+
240
+ # Fix for vLLM issue #19123: Set RAY_ADDRESS so vLLM subprocesses connect
241
+ # to the same Ray cluster instead of starting a new one (causes KeyError: 'bundles')
242
+ try:
243
+ import ray
244
+
245
+ if ray.is_initialized():
246
+ # Get the current Ray address and set it in the environment
247
+ # so vLLM's subprocess connects to the same cluster
248
+ ray_address = ray.get_runtime_context().gcs_address
249
+ if ray_address and "RAY_ADDRESS" not in os.environ:
250
+ os.environ["RAY_ADDRESS"] = ray_address
251
+ torchrl_logger.debug(
252
+ f"Set RAY_ADDRESS={ray_address} for vLLM subprocess"
253
+ )
254
+ except Exception:
255
+ pass # Ray not available or not initialized, let vLLM handle it
256
+
257
+ # Create the engine directly - this is the source of the blocking ray.get issue
258
+ # but we need to handle it differently for multiple replicas
259
+ self.engine = AsyncLLMEngine.from_engine_args(engine_args)
260
+ self.bundle_indices = bundle_indices
261
+
262
+ def ready(self) -> bool:
263
+ """Check if engine is ready for inference."""
264
+ return True
265
+
266
+ async def generate(
267
+ self,
268
+ prompts: Any = None,
269
+ sampling_params: SamplingParams | None = None,
270
+ *,
271
+ prompt_token_ids: list[int] | list[list[int]] | None = None,
272
+ use_tqdm: bool = True,
273
+ lora_request: Any = None,
274
+ prompt_adapter_request: Any = None,
275
+ guided_options_request: Any = None,
276
+ timeout_seconds: float | None = None,
277
+ ) -> RequestOutput | list[RequestOutput]:
278
+ """Generate text with the same interface as vLLM.LLM.generate.
279
+
280
+ This method mirrors the interface of vLLM.LLM.generate to provide seamless
281
+ compatibility between sync and async engines.
282
+
283
+ Args:
284
+ prompts: String, TokensPrompt, or list of these. Input prompts for generation.
285
+ sampling_params: SamplingParams object for controlling generation behavior.
286
+ prompt_token_ids: Alternative to prompts - token IDs for generation.
287
+ use_tqdm: Whether to show progress bar (not used in async engine).
288
+ lora_request: LoRA request for adapter-based generation.
289
+ guided_options_request: Guided decoding options.
290
+ timeout_seconds: Timeout for generation in seconds.
291
+
292
+ Returns:
293
+ RequestOutput or list of RequestOutput: Generated outputs from vLLM.
294
+ """
295
+ if not _has_vllm:
296
+ raise ImportError(
297
+ "vllm is not installed. Please install it with `pip install vllm`."
298
+ )
299
+
300
+ from vllm import SamplingParams, TokensPrompt
301
+
302
+ # Track whether input was originally a single prompt
303
+ single_prompt_input = False
304
+
305
+ # Handle prompt_token_ids if provided
306
+ if prompt_token_ids is not None:
307
+ if prompts is not None:
308
+ raise ValueError("Cannot specify both prompts and prompt_token_ids")
309
+
310
+ # Convert token IDs to TokensPrompt objects
311
+ if not prompt_token_ids:
312
+ raise ValueError("prompt_token_ids cannot be empty")
313
+
314
+ # Check if it's a list of lists or a single list
315
+ if prompt_token_ids and isinstance(prompt_token_ids[0], list):
316
+ # List of token ID lists
317
+ prompts = [
318
+ TokensPrompt(prompt_token_ids=tokens) for tokens in prompt_token_ids
319
+ ]
320
+ else:
321
+ # Single token ID list - cast to ensure type compatibility
322
+ token_list = list(prompt_token_ids) if prompt_token_ids else []
323
+ prompts = TokensPrompt(prompt_token_ids=token_list)
324
+ single_prompt_input = True
325
+
326
+ elif prompts is None:
327
+ raise ValueError("Must specify either prompts or prompt_token_ids")
328
+ else:
329
+ # prompts was provided directly
330
+ if not isinstance(prompts, (list, tuple)):
331
+ single_prompt_input = True
332
+
333
+ # Default sampling params if not provided
334
+ if sampling_params is None:
335
+ sampling_params = SamplingParams()
336
+
337
+ async def _gen_one(prompt) -> RequestOutput:
338
+ request_id = str(uuid.uuid4())
339
+ final = None
340
+
341
+ # Build kwargs for engine.generate
342
+ gen_kwargs = {
343
+ "prompt": prompt,
344
+ "sampling_params": sampling_params,
345
+ "request_id": request_id,
346
+ }
347
+
348
+ # Add optional parameters if provided
349
+ if lora_request is not None:
350
+ gen_kwargs["lora_request"] = lora_request
351
+ if prompt_adapter_request is not None:
352
+ gen_kwargs["prompt_adapter_request"] = prompt_adapter_request
353
+ if guided_options_request is not None:
354
+ gen_kwargs["guided_options_request"] = guided_options_request
355
+
356
+ async for output in self.engine.generate(**gen_kwargs):
357
+ if output.finished:
358
+ final = output
359
+ assert final is not None
360
+ return final
361
+
362
+ async def _run_generation():
363
+ if single_prompt_input:
364
+ return await _gen_one(prompts)
365
+
366
+ # List of prompts: run concurrently
367
+ tasks = [asyncio.create_task(_gen_one(p)) for p in prompts]
368
+ results = await asyncio.gather(*tasks)
369
+ return results
370
+
371
+ try:
372
+ if timeout_seconds is not None and timeout_seconds > 0:
373
+ return await asyncio.wait_for(
374
+ _run_generation(), timeout=timeout_seconds
375
+ )
376
+ else:
377
+ return await _run_generation()
378
+ except TimeoutError:
379
+ # Best-effort cleanup
380
+ try:
381
+ abort_fn = getattr(self.engine, "abort", None)
382
+ if callable(abort_fn):
383
+ # We can't easily track all request IDs, so this is best-effort
384
+ pass
385
+ except Exception:
386
+ pass
387
+ raise TimeoutError(
388
+ f"vLLM generation timed out after {timeout_seconds} seconds"
389
+ )
390
+
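# Editorial sketch (not part of this file): the fan-out pattern used by
# `_run_generation` above, in isolation - one task per prompt, gathered and
# bounded by a single timeout. Plain asyncio, no vLLM required.
import asyncio

async def fake_generate(prompt: str) -> str:
    await asyncio.sleep(0.01)  # stands in for the streaming engine loop
    return prompt.upper()

async def run_batch(prompts, timeout_seconds: float = 5.0):
    tasks = [asyncio.create_task(fake_generate(p)) for p in prompts]
    return await asyncio.wait_for(asyncio.gather(*tasks), timeout=timeout_seconds)

print(asyncio.run(run_batch(["a", "b", "c"])))  # ['A', 'B', 'C']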
391
+ async def get_tokenizer(self):
392
+ """Get the tokenizer from the engine."""
393
+ return await self.engine.get_tokenizer()
394
+
395
+ async def collective_rpc_v1(
396
+ self,
397
+ method: str,
398
+ timeout: float | None = None,
399
+ args: tuple = (),
400
+ kwargs: dict | None = None,
401
+ ):
402
+ """Perform a collective RPC call to the given method (vLLM V1).
403
+
404
+ Args:
405
+ method (str): Method name to call.
406
+ timeout (float | None): Timeout for the RPC call.
407
+ args (tuple): Arguments to pass to the method.
408
+ kwargs (dict | None): Keyword arguments to pass to the method.
409
+ """
410
+ from vllm import envs
411
+
412
+ if envs and envs.VLLM_USE_V1:
413
+ return await self.engine.collective_rpc(method, timeout, args, kwargs)
414
+ else:
415
+ return self.engine.engine.collective_rpc(method, timeout, args, kwargs)
416
+
417
+ def collective_rpc_v0(
418
+ self,
419
+ method: str,
420
+ timeout: float | None = None,
421
+ args: tuple = (),
422
+ kwargs: dict | None = None,
423
+ ):
424
+ """Perform a collective RPC call to the given method (vLLM V0).
425
+
426
+ Args:
427
+ method (str): Method name to call.
428
+ timeout (float | None): Timeout for the RPC call.
429
+ args (tuple): Arguments to pass to the method.
430
+ kwargs (dict | None): Keyword arguments to pass to the method.
431
+ """
432
+ return self.engine.engine.collective_rpc(method, timeout, args, kwargs)
433
+
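# Editorial sketch (not part of this file): collective_rpc_v1/v0 let callers
# reach worker-side methods such as `check_nccl_group_ready` defined above.
# Assumptions: `service` is a launched AsyncVLLM, vLLM V1 is active (use
# collective_rpc_v0 otherwise), and each RPC returns one value per worker.
import ray

def all_nccl_groups_ready(service) -> bool:
    futures = [
        actor.collective_rpc_v1.remote("check_nccl_group_ready")
        for actor in service.actors
    ]
    return all(all(flags) for flags in ray.get(futures))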
434
+ def get_num_unfinished_requests(self) -> int:
435
+ """Get the number of unfinished requests in the engine.
436
+
437
+ Returns:
438
+ int: Number of unfinished requests.
439
+ """
440
+ try:
441
+ # Try to access the method directly if available
442
+ if hasattr(self.engine, "get_num_unfinished_requests"):
443
+ return self.engine.get_num_unfinished_requests()
444
+ # Fallback to accessing through engine.engine for v0
445
+ elif hasattr(self.engine, "engine") and hasattr(
446
+ self.engine.engine, "get_num_unfinished_requests"
447
+ ):
448
+ return self.engine.engine.get_num_unfinished_requests()
449
+ else:
450
+ # If method not available, return 0 as fallback
451
+ torchrl_logger.warning(
452
+ "get_num_unfinished_requests not available, returning 0"
453
+ )
454
+ return 0
455
+ except Exception as e:
456
+ torchrl_logger.warning(f"Error getting unfinished requests count: {e}")
457
+ return 0
458
+
459
+ def get_cache_usage(self) -> float:
460
+ """Get the KV cache usage as a fraction between 0 and 1.
461
+
462
+ Returns:
463
+ float: Cache usage fraction (0.0 = empty, 1.0 = full).
464
+ """
465
+ try:
466
+ # Try to get cache usage from the engine
467
+ if hasattr(self.engine, "engine") and hasattr(
468
+ self.engine.engine, "cache_config"
469
+ ):
470
+ # Access the LLM engine's cache information
471
+ cache_config = self.engine.engine.cache_config
472
+ if hasattr(cache_config, "cache_usage"):
473
+ return cache_config.cache_usage
474
+ elif hasattr(self.engine.engine, "scheduler"):
475
+ # Try to get usage from the scheduler
476
+ scheduler = self.engine.engine.scheduler
477
+ if hasattr(scheduler, "get_num_free_gpu_blocks") and hasattr(
478
+ scheduler, "get_num_total_gpu_blocks"
479
+ ):
480
+ free_blocks = scheduler.get_num_free_gpu_blocks()
481
+ total_blocks = scheduler.get_num_total_gpu_blocks()
482
+ if total_blocks > 0:
483
+ return 1.0 - (free_blocks / total_blocks)
484
+ # Fallback: return a random value for now (this should be replaced with actual metrics)
485
+ torchrl_logger.warning(
486
+ "Cache usage metrics not available, returning random value"
487
+ )
488
+ return (
489
+ random.random() * 0.5
490
+ ) # Return a value between 0 and 0.5 to simulate partial usage
491
+ except Exception as e:
492
+ torchrl_logger.warning(f"Error getting cache usage: {e}")
493
+ return 0.0
494
+
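# Editorial sketch (not part of this file): a naive "least busy replica"
# picker built on the two metrics above. The real routing is done by the load
# balancer created via AsyncVLLM.create_load_balancer (which is also
# prefix-aware); this only shows how the metrics can be consumed.
import ray

def pick_least_busy_replica(service) -> int:
    pending = ray.get(
        [a.get_num_unfinished_requests.remote() for a in service.actors]
    )
    usage = ray.get([a.get_cache_usage.remote() for a in service.actors])
    # Pending requests dominate; KV-cache pressure breaks ties.
    scores = [p + u for p, u in zip(pending, usage)]
    return min(range(len(scores)), key=scores.__getitem__)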
495
+
496
+ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
497
+ """Get the number of GPUs per replica for the given engine args."""
498
+ return (
499
+ engine_args.tensor_parallel_size
500
+ * getattr(engine_args, "data_parallel_size", 1) # Default to 1 if not present
501
+ * getattr(
502
+ engine_args, "pipeline_parallel_size", 1
503
+ ) # Default to 1 if not present
504
+ )
505
+
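# Editorial worked example (illustrative numbers): with tensor_parallel_size=2
# and data/pipeline parallelism left at their defaults, each replica reserves
# 2 GPU bundles in _launch, and with num_replicas=3 the NCCL weight-sync group
# built in _init_weight_update_group_internal spans 3 * 2 + 1 = 7 ranks
# (rank 0 is the trainer; workers start at rank_offset = 1 + i * gpus_per_replica).
gpus_per_replica = 2 * 1 * 1
num_replicas = 3
weight_sync_world_size = num_replicas * gpus_per_replica + 1
print(gpus_per_replica, weight_sync_world_size)  # 2 7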
506
+
507
+ # Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
508
+
509
+
510
+ class AsyncVLLM(RLvLLMEngine):
511
+ """A service that manages multiple async vLLM engine actors for distributed inference.
512
+
513
+ This is the main entry point for async vLLM inference in TorchRL. It manages multiple
514
+ vLLM engine replicas running as Ray actors, providing load balancing, weight updates,
515
+ and a unified interface for text generation.
516
+
517
+ The service automatically handles Ray actor lifecycle management, GPU allocation through
518
+ placement groups, and provides both synchronous and asynchronous generation interfaces
519
+ that are compatible with the standard vLLM API.
520
+
521
+ Args:
522
+ engine_args (AsyncEngineArgs): Configuration for the vLLM engines.
523
+ num_replicas (int, optional): Number of engine replicas to create. Defaults to 1.
524
+ actor_class (optional): Custom Ray actor class. Defaults to the internal actor implementation.
525
+ enable_prefix_caching (bool, optional): Whether to enable prefix caching. Defaults to False.
526
+
527
+ .. warning::
528
+ enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
529
+ Set it to True if prompt log probs are not needed.
530
+ See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
531
+
532
+ Example:
533
+ >>> from torchrl.modules.llm import AsyncVLLM
534
+ >>> from vllm import SamplingParams
535
+ >>>
536
+ >>> # Simple usage - single GPU, single replica
537
+ >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
538
+ >>>
539
+ >>> # Advanced usage - multi-GPU tensor parallel with multiple replicas
540
+ >>> service = AsyncVLLM.from_pretrained(
541
+ ... "Qwen/Qwen2.5-7B",
542
+ ... num_devices=2, # Use 2 GPUs for tensor parallelism
543
+ ... num_replicas=2, # Create 2 replicas for higher throughput
544
+ ... max_model_len=4096
545
+ ... )
546
+ >>>
547
+ >>> # Generate text
548
+ >>> sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
549
+ >>> result = service.generate("Hello, world!", sampling_params)
550
+ >>> print(result.outputs[0].text)
551
+ >>>
552
+ >>> # Alternative: using AsyncEngineArgs directly for advanced configuration
553
+ >>> from vllm import AsyncEngineArgs
554
+ >>> engine_args = AsyncEngineArgs(
555
+ ... model="Qwen/Qwen2.5-3B",
556
+ ... tensor_parallel_size=2
557
+ ... )
558
+ >>> service = AsyncVLLM.launch(engine_args, num_replicas=2)
559
+
560
+ .. note::
561
+ **Architecture and Design**
562
+
563
+ The AsyncVLLM service implements a distributed inference architecture with the following key components:
564
+
565
+ 1. **Ray Actor Management**: Each replica runs as a separate Ray actor with dedicated GPU resources.
566
+ The service creates a placement group to ensure optimal GPU allocation and co-location of
567
+ tensor-parallel workers on the same node when possible.
568
+
569
+ 2. **Load Balancing**: Generation requests are distributed across replicas using random selection
570
+ by default, or can target specific replicas using the `actor_index` parameter.
571
+
572
+ 3. **Weight Synchronization**: The service supports weight updates across all replicas through
573
+ NCCL communication groups, enabling integration with distributed training workflows.
574
+
575
+ 4. **Resource Management**: Automatic GPU allocation and cleanup through Ray placement groups,
576
+ with proper shutdown procedures to prevent resource leaks.
577
+
578
+ 5. **API Compatibility**: Provides the same interface as vLLM's synchronous `LLM.generate()`
579
+ method, making it a drop-in replacement for async workloads.
580
+
581
+ **Ray Integration**
582
+
583
+ The service leverages Ray's actor model for distributed execution. Each replica is an independent
584
+ Ray actor that can be scheduled on different nodes. The service handles actor lifecycle,
585
+ monitors readiness, and provides centralized access to all replicas.
586
+
587
+ **Performance Considerations**
588
+
589
+ - Prefix caching can be enabled for better performance with repeated prompts (it is disabled by default; see the warning above)
590
+ - Tensor parallelism is supported for large models that don't fit on single GPUs
591
+ - Multiple replicas allow concurrent processing of different requests
592
+ - Native vLLM batching is used within each replica for optimal throughput
593
+
594
+ **Error Handling**
595
+
596
+ The service includes timeout support, graceful shutdown procedures, and best-effort
597
+ request cleanup on failures. Ray's fault tolerance mechanisms provide additional
598
+ resilience for long-running inference workloads.
599
+ """
600
+
601
+ def __init__(
602
+ self,
603
+ engine_args: AsyncEngineArgs,
604
+ num_replicas: int = 1,
605
+ actor_class=None,
606
+ enable_prefix_caching: bool = False,
607
+ ):
608
+ if not _has_vllm:
609
+ raise ImportError(
610
+ "vllm is not installed. Please install it with `pip install vllm`."
611
+ )
612
+ # Lazily import ray only when constructing the actor class to avoid global import
613
+
614
+ # Apply the requested prefix caching setting (disabled by default; see the class docstring warning)
615
+ engine_args.enable_prefix_caching = enable_prefix_caching
616
+
617
+ self.engine_args = engine_args
618
+ self.num_replicas = num_replicas
619
+ if actor_class is None:
620
+ ray = _get_ray()
621
+ self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
622
+ else:
623
+ self.actor_class = actor_class
624
+ self.actors: list = []
625
+ self._launched = False
626
+ self._service_id = uuid.uuid4().hex[
627
+ :8
628
+ ] # Unique suffix to avoid name collisions
629
+ self._placement_group = None
630
+ self._load_balancer = None
631
+
632
+ def _launch(self):
633
+ """Launch all actor replicas."""
634
+ if self._launched:
635
+ torchrl_logger.warning("AsyncVLLMEngineService already launched")
636
+ return
637
+
638
+ # Local imports to avoid global Ray dependency
639
+ ray = _get_ray()
640
+ from ray.util.placement_group import placement_group
641
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
642
+
643
+ torchrl_logger.info(
644
+ f"Launching {self.num_replicas} async vLLM engine actors..."
645
+ )
646
+
647
+ # Create placement groups - one per replica to avoid conflicts
648
+ self._placement_groups = []
649
+
650
+ # Create actor replicas sequentially to avoid race conditions
651
+ for i in range(self.num_replicas):
652
+ torchrl_logger.info(
653
+ f"Creating async actor replica {i + 1}/{self.num_replicas} ..."
654
+ )
655
+
656
+ # Create individual placement group for this replica
657
+ num_gpus = _gpus_per_replica(self.engine_args)
658
+ bundles = [{"GPU": 1.0, "CPU": 1.0} for _ in range(num_gpus)]
659
+ torchrl_logger.info(
660
+ f"Creating placement group for replica {i + 1} with {len(bundles)} bundles"
661
+ )
662
+
663
+ placement_group_name = f"vllm-replica-{self._service_id}-{i}"
664
+ pg = placement_group(bundles, strategy="PACK", name=placement_group_name)
665
+ self._placement_groups.append(pg)
666
+ torchrl_logger.info(f"Placement group {placement_group_name} created: {pg}")
667
+
668
+ # Wait for placement group to be ready
669
+ ray.get(pg.ready(), timeout=180)
670
+ torchrl_logger.info(f"Placement group {placement_group_name} ready")
671
+
672
+ # Calculate bundle indices for tensor parallelism
673
+ bundle_indices = None
674
+ if num_gpus > 1:
675
+ bundle_indices = list(range(num_gpus))
676
+ bundle_index = 0 # Always use first bundle since each replica has its own placement group
677
+
678
+ scheduling_strategy = PlacementGroupSchedulingStrategy(
679
+ placement_group=pg,
680
+ placement_group_capture_child_tasks=True,
681
+ placement_group_bundle_index=bundle_index,
682
+ )
683
+
684
+ actor = self.actor_class.options(
685
+ name=f"async-vllm-replica-{self._service_id}-{i}",
686
+ namespace="torchrl_vllm",
687
+ scheduling_strategy=scheduling_strategy,
688
+ num_gpus=0,
689
+ num_cpus=0,
690
+ ).remote(
691
+ engine_args=self.engine_args,
692
+ bundle_indices=bundle_indices,
693
+ enable_prefix_caching=self.engine_args.enable_prefix_caching,
694
+ )
695
+ self.actors.append(actor)
696
+
697
+ torchrl_logger.info("Waiting for actors to be ready")
698
+ # Wait for this actor to be ready before creating the next one
699
+ ready_futures = [actor.ready.remote() for actor in self.actors]
700
+ try:
701
+ ray.get(
702
+ ready_futures, timeout=TIMEOUT_SECONDS
703
+ ) # 5 minute timeout for engine initialization
704
+ torchrl_logger.info("✅ Actors are ready")
705
+ except Exception as e:
706
+ torchrl_logger.error(
707
+ f"❌ Failed to initialize actors within {TIMEOUT_SECONDS} seconds: {e}. You can increase the timeout by setting the TORCHRL_VLLM_TIMEOUT_SECONDS environment variable."
708
+ )
709
+ raise
710
+
711
+ # Store the first placement group for backward compatibility
712
+ self._placement_group = (
713
+ self._placement_groups[0] if self._placement_groups else None
714
+ )
715
+
716
+ self._launched = True
717
+ torchrl_logger.info(
718
+ f"✅ Successfully launched {len(self.actors)} async vLLM engine actors"
719
+ )
720
+
721
+ @classmethod
722
+ def launch(
723
+ cls,
724
+ engine_args: AsyncEngineArgs,
725
+ num_replicas: int = 1,
726
+ ) -> AsyncVLLM:
727
+ """Launch a new AsyncVLLMEngineService.
728
+
729
+ Args:
730
+ engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
731
+ num_replicas (int): Number of actor replicas to create.
732
+
733
+ Returns:
734
+ AsyncVLLM: The launched service.
735
+ """
736
+ service = cls(engine_args, num_replicas)
737
+ service._launch()
738
+ # create a default load balancer with smart routing
739
+ service.create_load_balancer()
740
+ return service
741
+
742
+ @classmethod
743
+ def from_pretrained(
744
+ cls,
745
+ model_name: str,
746
+ num_devices: int | None = None,
747
+ num_replicas: int = 1,
748
+ verbose: bool = True,
749
+ compile: bool = True,
750
+ enable_fp32_output: bool = False,
751
+ **kwargs,
752
+ ) -> AsyncVLLM:
753
+ """Create an AsyncVLLM instance from a pretrained model.
754
+
755
+ This is a convenience method that combines model loading and service launching
756
+ in a single call, similar to how other ML libraries work.
757
+
758
+ Args:
759
+ model_name (str): The model name to pass to vLLM.
760
+ num_devices (int, optional): Number of devices to use, per replica.
761
+ num_replicas (int): Number of engine replicas to create.
762
+ verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True.
763
+ compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
764
+ enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
765
+ **kwargs: Additional arguments passed to AsyncEngineArgs.
766
+
767
+ Returns:
768
+ AsyncVLLM: The launched async vLLM service.
769
+
770
+ Example:
771
+ >>> # Simple usage with defaults
772
+ >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
773
+ >>>
774
+ >>> # Multi-GPU tensor parallel with multiple replicas
775
+ >>> service = AsyncVLLM.from_pretrained(
776
+ ... "Qwen/Qwen2.5-7B",
777
+ ... num_devices=2,
778
+ ... num_replicas=2,
779
+ ... max_model_len=4096
780
+ ... )
781
+ >>>
782
+ >>> # Generate text
783
+ >>> from vllm import SamplingParams
784
+ >>> result = service.generate("Hello, world!", SamplingParams(max_tokens=50))
785
+ >>>
786
+ >>> # Enable FP32 output for better numerical stability
787
+ >>> service = AsyncVLLM.from_pretrained(
788
+ ... "Qwen/Qwen2.5-3B",
789
+ ... enable_fp32_output=True
790
+ ... )
791
+ """
792
+ return make_async_vllm_engine(
793
+ model_name=model_name,
794
+ num_devices=num_devices,
795
+ num_replicas=num_replicas,
796
+ verbose=verbose,
797
+ compile=compile,
798
+ enable_fp32_output=enable_fp32_output,
799
+ **kwargs,
800
+ )
801
+
802
+ def _is_batch(
803
+ self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
804
+ ) -> bool:
805
+ """Check if the input represents a batch of prompts.
806
+
807
+ Args:
808
+ prompts: Input prompts that can be string, TokensPrompt, or list of these
809
+ prompt_token_ids: Alternative token IDs input
810
+
811
+ Returns:
812
+ bool: True if this represents multiple prompts, False for single prompt
813
+ """
814
+ # If prompts is a list, we need to determine if it's a batch or a single prompt
815
+ if isinstance(prompts, list):
816
+ # Empty list is not a batch
817
+ if len(prompts) == 0:
818
+ return False
819
+
820
+ # If all elements are integers, it's a single prompt represented as token IDs
821
+ # We trust that if one is an int, then all are ints.
822
+ if any(isinstance(item, int) for item in prompts):
823
+ return False
824
+
825
+ # If it contains strings, TokensPrompt objects, or other non-integer types,
826
+ # it's a batch of prompts
827
+ return True
828
+
829
+ # If prompt_token_ids is provided and is a list of lists, it's a batch
830
+ if prompt_token_ids is not None and isinstance(prompt_token_ids, list):
831
+ if len(prompt_token_ids) > 0 and isinstance(prompt_token_ids[0], list):
832
+ return True
833
+
834
+ return False
835
+
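# Editorial sketch (not part of this file): concrete cases for the batch
# detection rules above. `_is_batch` does not touch `self`, so it can be
# probed unbound for illustration (assumes torchrl and vllm are importable).
from torchrl.modules.llm import AsyncVLLM

assert AsyncVLLM._is_batch(None, "hello") is False                            # single string prompt
assert AsyncVLLM._is_batch(None, ["hello", "world"]) is True                  # list of prompts
assert AsyncVLLM._is_batch(None, [101, 2023, 2003]) is False                  # one prompt given as token IDs
assert AsyncVLLM._is_batch(None, None, prompt_token_ids=[[1, 2], [3, 4]]) is True   # batch of token-ID lists
assert AsyncVLLM._is_batch(None, None, prompt_token_ids=[1, 2, 3]) is False   # single token-ID list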
836
+ def _iterate(
837
+ self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
838
+ ):
839
+ """Iterate over individual prompts in a batch.
840
+
841
+ Args:
842
+ prompts: Input prompts that can be string, TokensPrompt, or list of these
843
+ prompt_token_ids: Alternative token IDs input
844
+
845
+ Yields:
846
+ tuple: (individual_prompt, individual_prompt_token_ids) for each item
847
+ """
848
+ if isinstance(prompts, list):
849
+ # Check if this is actually a single prompt represented as token IDs
850
+ if all(isinstance(item, int) for item in prompts):
851
+ # This is a single prompt as token IDs, not a batch
852
+ yield prompts, prompt_token_ids
853
+ return
854
+
855
+ # Handle list of prompts (actual batch)
856
+ if prompt_token_ids is None:
857
+ for prompt in prompts:
858
+ yield prompt, None
859
+ elif (
860
+ isinstance(prompt_token_ids, list)
861
+ and len(prompt_token_ids) > 0
862
+ and isinstance(prompt_token_ids[0], list)
863
+ ):
864
+ # Both prompts and prompt_token_ids are lists
865
+ for prompt, token_ids in zip(prompts, prompt_token_ids):
866
+ yield prompt, token_ids
867
+ else:
868
+ # prompts is list, but prompt_token_ids is single list - replicate it
869
+ for prompt in prompts:
870
+ yield prompt, prompt_token_ids
871
+ else:
872
+ # Single prompt case
873
+ if (
874
+ prompt_token_ids is not None
875
+ and isinstance(prompt_token_ids, list)
876
+ and len(prompt_token_ids) > 0
877
+ and isinstance(prompt_token_ids[0], list)
878
+ ):
879
+ # Single prompt but multiple token_ids - replicate prompt
880
+ for token_ids in prompt_token_ids:
881
+ yield prompts, token_ids
882
+ else:
883
+ # Single prompt, single (or no) token_ids
884
+ yield prompts, prompt_token_ids
885
+
886
+ def _generate_impl(
887
+ self,
888
+ prompt: Any,
889
+ sampling_params: SamplingParams | None = None,
890
+ *,
891
+ prompt_token_ids: list[int] | None = None,
892
+ use_tqdm: bool = True,
893
+ lora_request: Any = None,
894
+ prompt_adapter_request: Any = None,
895
+ guided_options_request: Any = None,
896
+ timeout_seconds: float | None = None,
897
+ actor_index: int | None = None,
898
+ ):
899
+ """Generate text for a single prompt and return a Ray future.
900
+
901
+ This is the internal implementation that returns a future instead of the result.
902
+ Used for batched generation to enable parallel execution.
903
+
904
+ Args:
905
+ prompt: Single prompt (string, TokensPrompt, etc.)
906
+ sampling_params: SamplingParams object for controlling generation behavior
907
+ prompt_token_ids: Token IDs for a single prompt
908
+ use_tqdm: Whether to show progress bar (not used in async engine)
909
+ lora_request: LoRA request for adapter-based generation
910
+ prompt_adapter_request: Prompt adapter request
911
+ guided_options_request: Guided decoding options
912
+ timeout_seconds: Timeout for generation in seconds
913
+ actor_index: Specific actor to use (random if None)
914
+
915
+ Returns:
916
+ Ray ObjectRef: Future that will resolve to RequestOutput
917
+ """
918
+ if actor_index is None:
919
+ if len(self.actors) == 1:
920
+ actor = self.actors[0]
921
+ else:
922
+ if self._load_balancer is None:
923
+ raise RuntimeError(
924
+ "LoadBalancer is not created. Create a LoadBalancer using AsyncVLLM.create_load_balancer before calling generate."
925
+ )
926
+ # Extract single prompt for prefix-aware routing
927
+ single_prompt = self._extract_single_prompt_for_routing(
928
+ prompt, prompt_token_ids
929
+ )
930
+ actor_index = self._load_balancer.select_actor(prompt=single_prompt)
931
+ actor = self.actors[actor_index]
932
+ else:
933
+ actor = self.actors[actor_index]
934
+
935
+ return actor.generate.remote(
936
+ prompt,
937
+ sampling_params,
938
+ prompt_token_ids=prompt_token_ids,
939
+ use_tqdm=use_tqdm,
940
+ lora_request=lora_request,
941
+ prompt_adapter_request=prompt_adapter_request,
942
+ guided_options_request=guided_options_request,
943
+ timeout_seconds=timeout_seconds,
944
+ )
945
+
946
+ def generate(
947
+ self,
948
+ prompts: Any = None,
949
+ sampling_params: SamplingParams | None = None,
950
+ *,
951
+ prompt_token_ids: list[int] | list[list[int]] | None = None,
952
+ use_tqdm: bool = True,
953
+ lora_request: Any = None,
954
+ prompt_adapter_request: Any = None,
955
+ guided_options_request: Any = None,
956
+ timeout_seconds: float | None = None,
957
+ actor_index: int | None = None,
958
+ ) -> RequestOutput | list[RequestOutput]:
959
+ """Generate text using one of the actors with vLLM.LLM.generate interface.
960
+
961
+ This method provides the same interface as vLLM.LLM.generate for seamless
962
+ compatibility between sync and async engines. It can be used to generate text
963
+ within multiple threads / actors. If `actor_index` is not provided, the load balancer
964
+ will be used to select the actor.
965
+
966
+ `generate` is a blocking method, so it will wait for the generation to complete.
967
+
968
+ Args:
969
+ prompts (String, TokensPrompt, or list of these): Input prompts for generation.
970
+ sampling_params (SamplingParams): SamplingParams object for controlling generation behavior.
971
+ prompt_token_ids (list[int] | list[list[int]]): Alternative to prompts - token IDs for generation.
972
+ use_tqdm (bool): Whether to show progress bar (not used in async engine).
973
+ lora_request (Any): LoRA request for adapter-based generation.
974
+ prompt_adapter_request (Any): Prompt adapter request.
975
+ guided_options_request (Any): Guided decoding options.
976
+ timeout_seconds (float | None): Timeout for generation in seconds.
977
+ actor_index (int | None): Specific actor to use (selected via the load balancer if None).
978
+
979
+ Returns:
980
+ RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
981
+ """
982
+ ray = _get_ray()
983
+ # Check if this is a batch request
984
+ if self._is_batch(prompts, prompt_token_ids):
985
+ # Handle batched input by unbinding and sending individual requests
986
+ futures = []
987
+ for prompt, prompt_token_ids_i in self._iterate(prompts, prompt_token_ids):
988
+ future = self._generate_impl(
989
+ prompt,
990
+ sampling_params,
991
+ prompt_token_ids=prompt_token_ids_i,
992
+ use_tqdm=use_tqdm,
993
+ lora_request=lora_request,
994
+ prompt_adapter_request=prompt_adapter_request,
995
+ guided_options_request=guided_options_request,
996
+ timeout_seconds=timeout_seconds,
997
+ actor_index=actor_index,
998
+ )
999
+ futures.append(future)
1000
+
1001
+ # Collect all results
1002
+ results = ray.get(futures)
1003
+ return results
1004
+ else:
1005
+ # Single prompt case - call _generate_impl and get the result directly
1006
+ future = self._generate_impl(
1007
+ prompts,
1008
+ sampling_params,
1009
+ prompt_token_ids=prompt_token_ids,
1010
+ use_tqdm=use_tqdm,
1011
+ lora_request=lora_request,
1012
+ prompt_adapter_request=prompt_adapter_request,
1013
+ guided_options_request=guided_options_request,
1014
+ timeout_seconds=timeout_seconds,
1015
+ actor_index=actor_index,
1016
+ )
1017
+ result = ray.get(future)
1018
+ return result
1019
+
1020
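# A minimal usage sketch for the `generate` method above, assuming vLLM, Ray
# and at least one CUDA device are available. The import path and model name
# are illustrative and may need adjusting to your install.
from vllm import SamplingParams

from torchrl.modules.llm.backends.vllm.vllm_async import AsyncVLLM

service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=2)
service.create_load_balancer()  # required for actor selection when num_replicas > 1

params = SamplingParams(max_tokens=32)
single = service.generate("Hello, world!", params)                 # RequestOutput
batch = service.generate(["Hi there", "Tell me a joke"], params)   # list[RequestOutput]
print(single.outputs[0].text)
print([out.outputs[0].text for out in batch])

service.shutdown()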
+ def get_random_actor_index(self) -> int:
1021
+ """Get a random actor index."""
1022
+ return random.randint(0, len(self.actors) - 1)
1023
+
1024
+ def _init_weight_update_group_internal(self, master_address: str, master_port: str):
1025
+ """Initialize NCCL weight update group across all actors.
1026
+
1027
+ Args:
1028
+ master_address (str): Master address for distributed training.
1029
+ master_port (str): Master port for distributed training.
1030
+
1031
+ Returns:
1032
+ list: Ray futures for initialization calls.
1033
+ """
1034
+ gpus_per_replica = _gpus_per_replica(self.engine_args)
1035
+ weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
1036
+ torchrl_logger.info(
1037
+ f"AsyncVLLMEngineService requests weight update group for {self.num_replicas} actors "
1038
+ f"with {gpus_per_replica} GPUs per replica and {weight_sync_world_size} world size"
1039
+ )
1040
+
1041
+ from vllm import envs
1042
+
1043
+ refs = []
1044
+ for i, actor in enumerate(self.actors):
1045
+ rank_offset = 1 + i * gpus_per_replica
1046
+ if envs and envs.VLLM_USE_V1:
1047
+ actor_collective_rpc = actor.collective_rpc_v1
1048
+ else:
1049
+ actor_collective_rpc = actor.collective_rpc_v0
1050
+
1051
+ refs.append(
1052
+ actor_collective_rpc.remote(
1053
+ "init_weight_update_group",
1054
+ args=(
1055
+ master_address,
1056
+ master_port,
1057
+ rank_offset,
1058
+ weight_sync_world_size,
1059
+ ),
1060
+ )
1061
+ )
1062
+ torchrl_logger.info(
1063
+ f"AsyncVLLMEngineService args: {master_address=}, {master_port=}, "
1064
+ f"{rank_offset=}, {weight_sync_world_size=}"
1065
+ )
1066
+ torchrl_logger.info(
1067
+ f"AsyncVLLMEngineService requests weight update group for actor {i} "
1068
+ f"with rank_offset {rank_offset}"
1069
+ )
1070
+ return refs
1071
+
1072
+ def collective_rpc(
1073
+ self,
1074
+ method: str,
1075
+ timeout: float | None = None,
1076
+ args: tuple = (),
1077
+ kwargs: dict | None = None,
1078
+ ) -> list[Any]:
1079
+ """Forward an RPC to all actors.
1080
+
1081
+ Args:
1082
+ method (str): Method name to call.
1083
+ timeout (float | None): Timeout for the RPC call.
1084
+ args (tuple): Arguments to pass to the method.
1085
+ kwargs (dict | None): Keyword arguments to pass to the method.
1086
+
1087
+ Returns:
1088
+ list[Any]: Ray futures for all RPC calls.
1089
+ """
1090
+ from vllm import envs
1091
+
1092
+ futures = []
1093
+ for actor in self.actors:
1094
+ if envs and envs.VLLM_USE_V1:
1095
+ actor_collective_rpc = actor.collective_rpc_v1
1096
+ else:
1097
+ actor_collective_rpc = actor.collective_rpc_v0
1098
+ futures.append(actor_collective_rpc.remote(method, timeout, args, kwargs))
1099
+ return futures
1100
+
1101
+ def shutdown(self):
1102
+ """Shutdown all actors and clean up resources."""
1103
+ torchrl_logger.info(
1104
+ f"Shutting down {len(self.actors)} async vLLM engine actors..."
1105
+ )
1106
+
1107
+ ray = _get_ray()
1108
+ from ray.util.placement_group import remove_placement_group
1109
+
1110
+ # Kill all actors
1111
+ for i, actor in enumerate(self.actors):
1112
+ try:
1113
+ ray.kill(actor)
1114
+ torchrl_logger.info(f"Shutdown async actor {i + 1}/{len(self.actors)}")
1115
+ except Exception as e:
1116
+ torchrl_logger.warning(f"Error shutting down async actor {i + 1}: {e}")
1117
+
1118
+ # Clear the actors list
1119
+ self.actors.clear()
1120
+
1121
+ # Remove placement groups if any
1122
+ if hasattr(self, "_placement_groups") and self._placement_groups:
1123
+ for i, pg in enumerate(self._placement_groups):
1124
+ try:
1125
+ remove_placement_group(pg)
1126
+ torchrl_logger.info(
1127
+ f"Removed placement group {i + 1}/{len(self._placement_groups)}"
1128
+ )
1129
+ except Exception as e:
1130
+ torchrl_logger.warning(
1131
+ f"Error removing placement group {i + 1}: {e}"
1132
+ )
1133
+ self._placement_groups = []
1134
+
1135
+ # Remove legacy single placement group if any
1136
+ if self._placement_group is not None:
1137
+ remove_placement_group(self._placement_group)
1138
+ self._placement_group = None
1139
+ self._launched = False
1140
+ torchrl_logger.info("AsyncVLLMEngineService shutdown complete")
1141
+
1142
+ # RLvLLMEngine interface implementation
1143
+ def get_tp_size(self) -> int:
1144
+ """Get the tensor parallel size."""
1145
+ return self.engine_args.tensor_parallel_size
1146
+
1147
+ def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
1148
+ """Get model parameter metadata.
1149
+
1150
+ Note: This requires the model to be loaded. For now, we return an empty dict
1151
+ and expect the metadata to be provided externally during weight updates.
1152
+ """
1153
+ # TODO: Implement metadata extraction from loaded model
1154
+ # This would require accessing the model from one of the actors
1155
+ torchrl_logger.warning(
1156
+ "AsyncVLLM.get_model_metadata() not yet implemented - returning empty dict"
1157
+ )
1158
+ return {}
1159
+
1160
+ def get_master_address(self) -> str:
1161
+ """Get the master address for weight synchronization."""
1162
+ return "localhost" # Default for now
1163
+
1164
+ def get_master_port(self) -> int:
1165
+ """Get the master port for weight synchronization."""
1166
+ # Cache the port like V1 does to ensure consistency
1167
+ if not hasattr(self, "_cached_master_port"):
1168
+ if _has_vllm:
1169
+ try:
1170
+ from vllm.utils import get_open_port
1171
+
1172
+ self._cached_master_port = get_open_port()
1173
+ except ImportError:
1174
+ self._cached_master_port = 29500 # Default port if import fails
1175
+ else:
1176
+ self._cached_master_port = 29500 # Default port
1177
+ return self._cached_master_port
1178
+
1179
+ def init_weight_update_group(
1180
+ self,
1181
+ master_address: str,
1182
+ master_port: int | str,
1183
+ ) -> list[Any]:
1184
+ """Forward the request to init NCCL weight update group to all actors.
1185
+
1186
+ This method initializes the weight update group for all vLLM workers.
1187
+ The external trainer should be rank 0, and vLLM workers will be ranks 1+.
1188
+
1189
+ Args:
1190
+ master_address: Master address for NCCL communication.
1191
+ master_port: Master port for NCCL communication.
1192
+
1193
+ Returns:
1194
+ List of Ray futures for the initialization calls.
1195
+
1196
+ Note:
1197
+ The caller must wait on the returned futures (ray.get(refs)) to ensure
1198
+ all workers have completed initialization before sending weights.
1199
+ """
1200
+ if not self._launched:
1201
+ raise RuntimeError(
1202
+ "AsyncVLLM service must be launched before initializing weight update group"
1203
+ )
1204
+
1205
+ gpus_per_replica = _gpus_per_replica(self.engine_args)
1206
+ weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
1207
+
1208
+ torchrl_logger.info(
1209
+ f"Initializing weight update group for {self.num_replicas} replicas "
1210
+ f"with {gpus_per_replica} GPUs each (world_size={weight_sync_world_size})"
1211
+ )
1212
+
1213
+ from vllm import envs
1214
+
1215
+ refs = []
1216
+ for i, actor in enumerate(self.actors):
1217
+ rank_offset = 1 + i * gpus_per_replica
1218
+ if envs and envs.VLLM_USE_V1:
1219
+ actor_collective_rpc = actor.collective_rpc_v1
1220
+ else:
1221
+ actor_collective_rpc = actor.collective_rpc_v0
1222
+ refs.append(
1223
+ actor_collective_rpc.remote(
1224
+ "init_weight_update_group",
1225
+ args=(
1226
+ master_address,
1227
+ str(master_port),
1228
+ rank_offset,
1229
+ weight_sync_world_size,
1230
+ ),
1231
+ )
1232
+ )
1233
+ torchrl_logger.info(
1234
+ f"Requested init for actor {i} with rank_offset {rank_offset}"
1235
+ )
1236
+
1237
+ return refs
1238
+
1239
+ def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
1240
+ """Update model weights across all replicas using NCCL broadcast.
1241
+
1242
+ Args:
1243
+ weights: Iterator yielding (parameter_name, tensor) tuples
1244
+ """
1245
+ if not self._launched:
1246
+ raise RuntimeError(
1247
+ "AsyncVLLM service must be launched before updating weights"
1248
+ )
1249
+
1250
+ # Convert iterator to dict for easier handling
1251
+ weights_dict = dict(weights)
1252
+
1253
+ if not weights_dict:
1254
+ torchrl_logger.warning("No weights provided for update")
1255
+ return
1256
+
1257
+ torchrl_logger.info(
1258
+ f"Updating {len(weights_dict)} parameters across {len(self.actors)} replicas using NCCL broadcast"
1259
+ )
1260
+
1261
+ self._update_weights_with_nccl_broadcast_simple(weights_dict)
1262
+
1263
+ torchrl_logger.info("AsyncVLLM NCCL weight update completed")
1264
+
1265
+ def _update_weights_with_nccl_broadcast_simple(
1266
+ self, weights_dict: dict[str, torch.Tensor]
1267
+ ) -> None:
1268
+ """Update weights using simple NCCL broadcast like V1.
1269
+
1270
+ This approach follows the V1 pattern:
1271
+ 1. Training process (master) broadcasts as rank 0
1272
+ 2. All vLLM workers receive as ranks 1, 2, 3...
1273
+ 3. Simple and reliable like the working V1 implementation
1274
+
1275
+ Args:
1276
+ weights_dict: Dictionary of parameter names to weight tensors
1277
+ """
1278
+ if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None:
1279
+ raise RuntimeError(
1280
+ "NCCL master group not initialized. This is a bug in the setup process."
1281
+ )
1282
+
1283
+ t0 = time.time()
1284
+
1285
+ # Move all weights to cuda:0 (matching NCCL communicator device)
1286
+ gpu_weights = {}
1287
+ for name, weight in weights_dict.items():
1288
+ # Ensure weight is on cuda:0 (matching NCCL communicator)
1289
+ if weight.device != torch.device("cuda:0"):
1290
+ gpu_weights[name] = weight.to("cuda:0", non_blocking=True)
1291
+ else:
1292
+ gpu_weights[name] = weight
1293
+
1294
+ # Use periodic-mono pattern: individual weight updates with immediate RPC->NCCL
1295
+ torchrl_logger.info(
1296
+ f"Updating {len(gpu_weights)} weights using periodic-mono pattern..."
1297
+ )
1298
+
1299
+ updated_weights = 0
1300
+ ray = _get_ray()
1301
+ with torch.cuda.device(0): # Ensure we're on the correct CUDA device
1302
+ for name, weight in gpu_weights.items():
1303
+ # Convert dtype to string name (like periodic-mono)
1304
+ dtype_name = str(weight.dtype).split(".")[
1305
+ -1
1306
+ ] # "torch.bfloat16" -> "bfloat16"
1307
+
1308
+ # Step 1: Send RPC to workers for this weight
1309
+ futures = self.collective_rpc(
1310
+ "update_weight", args=(name, dtype_name, tuple(weight.shape))
1311
+ )
1312
+
1313
+ # Step 2: Immediately broadcast this weight (like periodic-mono)
1314
+ self._nccl_master_group.broadcast(
1315
+ weight, src=0, stream=torch.cuda.current_stream()
1316
+ )
1317
+
1318
+ # Step 3: Wait for workers to complete this weight
1319
+ ray.get(futures)
1320
+ updated_weights += 1
1321
+
1322
+ torch.cuda.synchronize()
1323
+ t2 = time.time()
1324
+ torchrl_logger.info(
1325
+ f"Successfully updated {updated_weights}/{len(gpu_weights)} weights in {t2 - t0:.3f}s"
1326
+ )
1327
+
1328
+ def _setup_nccl_master_group(self) -> None:
1329
+ """Set up NCCL communication group for the master node (rank 0)."""
1330
+ # Calculate world size (should match what workers use)
1331
+ gpus_per_replica = _gpus_per_replica(self.engine_args)
1332
+ weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
1333
+
1334
+ master_address = self.get_master_address()
1335
+ master_port = self.get_master_port()
1336
+
1337
+ torchrl_logger.info(
1338
+ f"Setting up NCCL master group: rank=0, world_size={weight_sync_world_size}, "
1339
+ f"address={master_address}:{master_port}"
1340
+ )
1341
+
1342
+ # Ensure CUDA is available and initialized
1343
+ if not torch.cuda.is_available():
1344
+ raise RuntimeError("CUDA not available for NCCL communication")
1345
+
1346
+ # Set CUDA device before initializing NCCL
1347
+ torch.cuda.set_device(0)
1348
+
1349
+ # Initialize master as rank 0 in the NCCL group (use synchronous version)
1350
+ self._nccl_master_group = stateless_init_process_group(
1351
+ master_address=master_address,
1352
+ master_port=str(master_port),
1353
+ rank=0, # Master is always rank 0
1354
+ world_size=weight_sync_world_size,
1355
+ device=torch.device("cuda:0"),
1356
+ )
1357
+
1358
+ torchrl_logger.info("NCCL master group initialized successfully")
1359
+
1360
+ def get_num_unfinished_requests(
1361
+ self, actor_index: int | None = None
1362
+ ) -> int | list[int]:
1363
+ """Get the number of unfinished requests for one or all actors.
1364
+
1365
+ Args:
1366
+ actor_index (int | None): Index of specific actor, or None for all actors.
1367
+
1368
+ Returns:
1369
+ int | list[int]: Number of unfinished requests for the specified actor,
1370
+ or list of counts for all actors if actor_index is None.
1371
+ """
1372
+ if not self._launched:
1373
+ raise RuntimeError(
1374
+ "AsyncVLLM service must be launched before getting request counts"
1375
+ )
1376
+
1377
+ ray = _get_ray()
1378
+ if actor_index is not None:
1379
+ if not (0 <= actor_index < len(self.actors)):
1380
+ raise IndexError(
1381
+ f"Actor index {actor_index} out of range [0, {len(self.actors)})"
1382
+ )
1383
+
1384
+ actor = self.actors[actor_index]
1385
+ return ray.get(actor.get_num_unfinished_requests.remote())
1386
+ else:
1387
+ # Get counts from all actors
1388
+ futures = [
1389
+ actor.get_num_unfinished_requests.remote() for actor in self.actors
1390
+ ]
1391
+ return ray.get(futures)
1392
+
1393
+ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]:
1394
+ """Get the KV cache usage for one or all actors.
1395
+
1396
+ Args:
1397
+ actor_index (int | None): Index of specific actor, or None for all actors.
1398
+
1399
+ Returns:
1400
+ float | list[float]: Cache usage fraction for the specified actor,
1401
+ or list of usage fractions for all actors if actor_index is None.
1402
+ """
1403
+ if not self._launched:
1404
+ raise RuntimeError(
1405
+ "AsyncVLLM service must be launched before getting cache usage"
1406
+ )
1407
+
1408
+ ray = _get_ray()
1409
+ if actor_index is not None:
1410
+ if not (0 <= actor_index < len(self.actors)):
1411
+ raise IndexError(
1412
+ f"Actor index {actor_index} out of range [0, {len(self.actors)})"
1413
+ )
1414
+
1415
+ actor = self.actors[actor_index]
1416
+ return ray.get(actor.get_cache_usage.remote())
1417
+ else:
1418
+ # Get usage from all actors
1419
+ futures = [actor.get_cache_usage.remote() for actor in self.actors]
1420
+ return ray.get(futures)
1421
+
1422
+ def create_load_balancer(
1423
+ self,
1424
+ strategy: Literal["requests", "kv-cache"]
1425
+ | Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
1426
+ | None = None,
1427
+ **kwargs,
1428
+ ) -> LoadBalancer:
1429
+ """Create a load balancer for this AsyncVLLM service.
1430
+
1431
+ Args:
1432
+ strategy: Load balancing strategy or sequence of strategies in fallback order.
1433
+ Default: ["prefix-aware", "requests"] - tries cache-aware routing first,
1434
+ then falls back to least-pending-requests routing. Accepts a single strategy (e.g. "requests"
1435
+ or "kv-cache") or a fallback sequence such as ["prefix-aware", "requests", "round-robin"].
1436
+ **kwargs: Additional arguments passed to LoadBalancer constructor.
1437
+
1438
+ Returns:
1439
+ LoadBalancer: Configured load balancer instance. This is stored in the AsyncVLLM instance.
1440
+
1441
+ Examples:
1442
+ >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)
1443
+
1444
+ >>> # Use smart defaults (prefix-aware -> requests)
1445
+ >>> lb = service.create_load_balancer()
1446
+ >>> selected_actor_index = lb.select_actor(prompt="Hello world")
1447
+
1448
+ >>> # Simple single strategy
1449
+ >>> lb = service.create_load_balancer("requests")
1450
+ >>> selected_actor_index = lb.select_actor()
1451
+
1452
+ >>> # Custom strategy hierarchy
1453
+ >>> lb = service.create_load_balancer(
1454
+ ... ["prefix-aware", "kv-cache", "round-robin"],
1455
+ ... prefix_length=16,
1456
+ ... overload_threshold=2.0
1457
+ ... )
1458
+ >>> selected_actor_index = lb.select_actor(prompt="Hello world")
1459
+ """
1460
+ if not self._launched:
1461
+ raise RuntimeError(
1462
+ "AsyncVLLM service must be launched before creating load balancer"
1463
+ )
1464
+
1465
+ load_balancer = LoadBalancer(self, strategy, **kwargs)
1466
+ self._load_balancer = load_balancer
1467
+ return load_balancer
1468
+
1469
+ def _extract_single_prompt_for_routing(
1470
+ self,
1471
+ prompts: Any = None,
1472
+ prompt_token_ids: list[int] | list[list[int]] | None = None,
1473
+ ) -> str | list[int] | None:
1474
+ """Extract a single prompt for load balancer routing, if possible.
1475
+
1476
+ Args:
1477
+ prompts: The prompts argument passed to generate().
1478
+ prompt_token_ids: The prompt_token_ids argument passed to generate().
1479
+
1480
+ Returns:
1481
+ str | list[int] | None: Single prompt for routing, or None if multiple prompts.
1482
+ """
1483
+ try:
1484
+ # Handle prompt_token_ids first (takes precedence over prompts)
1485
+ if prompt_token_ids is not None:
1486
+ if isinstance(prompt_token_ids, list):
1487
+ if len(prompt_token_ids) == 0:
1488
+ return None # Empty list
1489
+ elif len(prompt_token_ids) == 1:
1490
+ # Single prompt case - could be tokens directly or nested list
1491
+ if isinstance(prompt_token_ids[0], int):
1492
+ # Single token sequence: [token1, token2, ...]
1493
+ return prompt_token_ids
1494
+ elif isinstance(prompt_token_ids[0], list):
1495
+ # Nested list with single prompt: [[token1, token2, ...]]
1496
+ return prompt_token_ids[0]
1497
+ else:
1498
+ return None
1499
+ else:
1500
+ # Multiple prompts: [[tokens1...], [tokens2...], ...]
1501
+ return None
1502
+ else:
1503
+ # Not a list, invalid format
1504
+ return None
1505
+
1506
+ # Handle prompts argument
1507
+ if prompts is None:
1508
+ return None
1509
+
1510
+ # No vLLM prompt-type imports are needed here: the duck-typed checks
1511
+ # below handle TokensPrompt-like and TextPrompt-like objects directly.
1517
+
1518
+ # Single string prompt
1519
+ if isinstance(prompts, str):
1520
+ return prompts
1521
+
1522
+ # TokensPrompt object
1523
+ elif hasattr(prompts, "prompt_token_ids"): # TokensPrompt-like object
1524
+ return prompts.prompt_token_ids
1525
+
1526
+ # TextPrompt object
1527
+ elif hasattr(prompts, "prompt"): # TextPrompt-like object
1528
+ return prompts.prompt
1529
+
1530
+ # List of prompts
1531
+ elif isinstance(prompts, (list, tuple)):
1532
+ if len(prompts) == 0:
1533
+ return None # Empty list
1534
+ elif len(prompts) == 1:
1535
+ # Single prompt in list - recursively extract
1536
+ return self._extract_single_prompt_for_routing(prompts[0], None)
1537
+ else:
1538
+ # Multiple prompts - cannot do prefix routing
1539
+ return None
1540
+
1541
+ # Other types (shouldn't happen in normal usage)
1542
+ else:
1543
+ torchrl_logger.debug(
1544
+ f"Unknown prompt type for routing: {type(prompts)}"
1545
+ )
1546
+ return None
1547
+
1548
+ except Exception as e:
1549
+ torchrl_logger.debug(f"Error extracting single prompt for routing: {e}")
1550
+ return None
1551
+
1552
+
1553
+ class LoadBalancer:
1554
+ """Load balancer for distributing requests across AsyncVLLM actors with strategy hierarchy.
1555
+
1556
+ This class implements sophisticated load balancing with multiple strategies and intelligent
1557
+ fallback mechanisms. Strategies are tried in order until one succeeds, providing robust
1558
+ request routing even when some strategies fail.
1559
+
1560
+ Args:
1561
+ actors: Either a single AsyncVLLM instance or a list of Ray actors.
1562
+ strategy: Single strategy or sequence of strategies in fallback order.
1563
+ Available strategies:
1564
+
1565
+ - "prefix-aware": Route based on prompt prefix for cache locality
1566
+ - "requests": Select actor with fewest pending requests
1567
+ - "kv-cache": Select actor with lowest KV cache utilization
1568
+ - "round-robin": Simple round-robin distribution
1569
+
1570
+ Default: ["prefix-aware", "requests"]
1571
+
1572
+ prefix_length: Number of tokens/words to use for prefix routing (default: 8).
1573
+ overload_threshold: Multiplier for average load to consider actor overloaded (default: 1.5).
1574
+
1575
+ Examples:
1576
+ >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)
1577
+
1578
+ >>> # Simple strategy
1579
+ >>> lb = LoadBalancer(service, "requests")
1580
+ >>> actor_idx = lb.select_actor()
1581
+
1582
+ >>> # Strategy hierarchy: try prefix-aware first, fall back to requests, then round-robin
1583
+ >>> lb = LoadBalancer(service, ["prefix-aware", "requests", "round-robin"])
1584
+ >>> actor_idx = lb.select_actor(prompt="Hello world") # Uses prefix routing
1585
+ >>> actor_idx = lb.select_actor() # Falls back to requests (no prompt)
1586
+
1587
+ >>> # Custom configuration
1588
+ >>> lb = LoadBalancer(
1589
+ ... service,
1590
+ ... ["prefix-aware", "kv-cache"],
1591
+ ... prefix_length=16,
1592
+ ... overload_threshold=2.0
1593
+ ... )
1594
+ """
1595
+
1596
+ def __init__(
1597
+ self,
1598
+ actors: list[Any] | AsyncVLLM,
1599
+ strategy: Literal["requests", "kv-cache"]
1600
+ | Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
1601
+ | None = None,
1602
+ prefix_length: int = 8,
1603
+ overload_threshold: float = 1.5,
1604
+ ):
1605
+ if strategy is None:
1606
+ strategy = ["prefix-aware", "requests"]
1607
+ # Handle both AsyncVLLM instances and direct actor lists
1608
+ if hasattr(actors, "actors"): # AsyncVLLM instance
1609
+ self.actors = actors.actors
1610
+ self.async_vllm = actors
1611
+ elif isinstance(actors, list): # Direct list of actors
1612
+ self.actors = actors
1613
+ self.async_vllm = None
1614
+ else:
1615
+ raise ValueError(
1616
+ "actors must be either an AsyncVLLM instance or a list of actors"
1617
+ )
1618
+
1619
+ if not self.actors:
1620
+ raise ValueError("No actors provided")
1621
+
1622
+ # Handle both single strategy and strategy hierarchy
1623
+ if isinstance(strategy, str):
1624
+ self.strategies = [strategy]
1625
+ else:
1626
+ self.strategies = list(strategy)
1627
+
1628
+ # Validate strategies
1629
+ valid_strategies = {"prefix-aware", "requests", "kv-cache", "round-robin"}
1630
+ for s in self.strategies:
1631
+ if s not in valid_strategies:
1632
+ raise ValueError(
1633
+ f"Invalid strategy '{s}'. Must be one of {valid_strategies}"
1634
+ )
1635
+
1636
+ if not self.strategies:
1637
+ raise ValueError("At least one strategy must be provided")
1638
+
1639
+ self.strategy = self.strategies[
1640
+ 0
1641
+ ] # Primary strategy for backward compatibility
1642
+ self.prefix_length = prefix_length
1643
+ self.overload_threshold = overload_threshold
1644
+ self._round_robin_index = 0 # For round-robin fallback
1645
+
1646
+ def select_actor(
1647
+ self,
1648
+ prompt: str | list[int] | None = None,
1649
+ request_context: dict[str, Any] | None = None,
1650
+ ) -> int:
1651
+ """Select the optimal actor index based on the configured strategy hierarchy.
1652
+
1653
+ Args:
1654
+ prompt: The input prompt (string or token list) for prefix-aware routing.
1655
+ request_context: Additional context for routing decisions.
1656
+
1657
+ Returns:
1658
+ int: Index of the selected actor in the actors list.
1659
+
1660
+ Raises:
1661
+ RuntimeError: If unable to gather metrics from actors.
1662
+ ValueError: If no actors are available.
1663
+ """
1664
+ if not self.actors:
1665
+ raise ValueError("No actors available for selection")
1666
+
1667
+ # Try each strategy in order until one succeeds
1668
+ for i, strategy in enumerate(self.strategies):
1669
+ try:
1670
+ torchrl_logger.debug(
1671
+ f"Trying strategy {i + 1}/{len(self.strategies)}: {strategy}"
1672
+ )
1673
+
1674
+ if strategy == "prefix-aware":
1675
+ if prompt is not None:
1676
+ return self._select_by_prefix_aware(prompt)
1677
+ else:
1678
+ torchrl_logger.debug(
1679
+ "No prompt provided for prefix-aware routing, trying next strategy"
1680
+ )
1681
+ continue
1682
+
1683
+ elif strategy == "requests":
1684
+ return self._select_by_requests()
1685
+
1686
+ elif strategy == "kv-cache":
1687
+ return self._select_by_cache_usage()
1688
+
1689
+ elif strategy == "round-robin":
1690
+ return self._select_round_robin()
1691
+
1692
+ else:
1693
+ torchrl_logger.warning(
1694
+ f"Unknown strategy: {strategy}, trying next strategy"
1695
+ )
1696
+ continue
1697
+
1698
+ except Exception as e:
1699
+ torchrl_logger.warning(
1700
+ f"Strategy '{strategy}' failed with error: {e}. "
1701
+ f"Trying next strategy..."
1702
+ )
1703
+ continue
1704
+
1705
+ # All strategies failed, final fallback to random
1706
+ torchrl_logger.warning(
1707
+ f"All strategies {self.strategies} failed. Falling back to random selection."
1708
+ )
1709
+ return random.randint(0, len(self.actors) - 1)
1710
+
1711
+ def _select_by_requests(self) -> int:
1712
+ """Select actor with fewest pending requests."""
1713
+ if self.async_vllm is not None:
1714
+ # Use AsyncVLLM's built-in method to get request counts
1715
+ request_counts = self.async_vllm.get_num_unfinished_requests()
1716
+ else:
1717
+ # Query actors directly
1718
+ futures = [
1719
+ actor.get_num_unfinished_requests.remote() for actor in self.actors
1720
+ ]
1721
+ ray = _get_ray()
1722
+ request_counts = ray.get(futures)
1723
+
1724
+ # Find the actor with minimum pending requests
1725
+ min_requests = min(request_counts)
1726
+ min_indices = [
1727
+ i for i, count in enumerate(request_counts) if count == min_requests
1728
+ ]
1729
+
1730
+ # If multiple actors have the same minimum count, choose randomly among them
1731
+ selected_index = random.choice(min_indices)
1732
+
1733
+ torchrl_logger.debug(
1734
+ f"LoadBalancer (requests): Selected actor {selected_index} "
1735
+ f"with {min_requests} pending requests. "
1736
+ f"Request counts: {request_counts}"
1737
+ )
1738
+
1739
+ return selected_index
1740
+
1741
+ def _select_by_cache_usage(self) -> int:
1742
+ """Select actor with lowest KV cache utilization."""
1743
+ if self.async_vllm is not None:
1744
+ # Use AsyncVLLM's built-in method to get cache usage
1745
+ cache_usages = self.async_vllm.get_cache_usage()
1746
+ else:
1747
+ # Query actors directly
1748
+ futures = [actor.get_cache_usage.remote() for actor in self.actors]
1749
+ ray = _get_ray()
1750
+ cache_usages = ray.get(futures)
1751
+
1752
+ # Find the actor with minimum cache usage
1753
+ min_usage = min(cache_usages)
1754
+ min_indices = [
1755
+ i for i, usage in enumerate(cache_usages) if abs(usage - min_usage) < 1e-6
1756
+ ]
1757
+
1758
+ # If multiple actors have similar cache usage, choose randomly among them
1759
+ selected_index = random.choice(min_indices)
1760
+
1761
+ torchrl_logger.debug(
1762
+ f"LoadBalancer (kv-cache): Selected actor {selected_index} "
1763
+ f"with {min_usage:.3f} cache usage. "
1764
+ f"Cache usages: {[f'{u:.3f}' for u in cache_usages]}"
1765
+ )
1766
+
1767
+ return selected_index
1768
+
1769
+ def _select_by_prefix_aware(self, prompt: str | list[int]) -> int:
1770
+ """Select actor based on prompt prefix for cache locality.
1771
+
1772
+ Args:
1773
+ prompt: Input prompt as string or token list.
1774
+
1775
+ Returns:
1776
+ int: Selected actor index.
1777
+
1778
+ Raises:
1779
+ ValueError: If prefix cannot be extracted.
1780
+ """
1781
+ try:
1782
+ # Extract prefix tokens
1783
+ prefix_tokens = self._extract_prefix_tokens(prompt)
1784
+ if not prefix_tokens:
1785
+ raise ValueError("Could not extract meaningful prefix tokens")
1786
+
1787
+ # Create consistent hash from prefix
1788
+ prefix_hash = hash(tuple(prefix_tokens))
1789
+ preferred_actor = prefix_hash % len(self.actors)
1790
+
1791
+ # Check if preferred actor is overloaded
1792
+ if self._is_actor_overloaded(preferred_actor):
1793
+ torchrl_logger.debug(
1794
+ f"Preferred actor {preferred_actor} is overloaded "
1795
+ f"(threshold: {self.overload_threshold}), falling back to load-based selection"
1796
+ )
1797
+ # Fall back to requests-based selection
1798
+ return self._select_by_requests()
1799
+
1800
+ torchrl_logger.debug(
1801
+ f"LoadBalancer (prefix-aware): Selected actor {preferred_actor} "
1802
+ f"for prefix hash {prefix_hash} (tokens: {prefix_tokens[:4]}...)"
1803
+ )
1804
+
1805
+ return preferred_actor
1806
+
1807
+ except Exception as e:
1808
+ torchrl_logger.warning(f"Prefix-aware routing failed: {e}")
1809
+ raise
1810
+
1811
+ def _select_round_robin(self) -> int:
1812
+ """Select actor using round-robin strategy."""
1813
+ selected = self._round_robin_index % len(self.actors)
1814
+ self._round_robin_index = (self._round_robin_index + 1) % len(self.actors)
1815
+
1816
+ torchrl_logger.debug(f"LoadBalancer (round-robin): Selected actor {selected}")
1817
+ return selected
1818
+
1819
+ def _extract_prefix_tokens(self, prompt: str | list[int]) -> list[int]:
1820
+ """Extract prefix tokens from prompt (string or token list).
1821
+
1822
+ Args:
1823
+ prompt: Input prompt.
1824
+
1825
+ Returns:
1826
+ list[int]: Prefix tokens (up to self.prefix_length).
1827
+
1828
+ Raises:
1829
+ ValueError: If tokenization fails or prompt is invalid.
1830
+ """
1831
+ if isinstance(prompt, list):
1832
+ # Already tokenized
1833
+ if not prompt:
1834
+ raise ValueError("Empty token list provided")
1835
+ return prompt[: self.prefix_length]
1836
+
1837
+ elif isinstance(prompt, str):
1838
+ # Need to tokenize - this requires access to tokenizer
1839
+ if not prompt.strip():
1840
+ raise ValueError("Empty or whitespace-only string provided")
1841
+
1842
+ # Try to get tokenizer from AsyncVLLM instance
1843
+ if self.async_vllm is not None:
1844
+ try:
1845
+ # This is a simplistic approach - in practice you'd want to cache the tokenizer
1846
+ # For now, use a simple heuristic based on string content
1847
+ return self._simple_string_hash(prompt)
1848
+ except Exception as e:
1849
+ torchrl_logger.warning(f"Could not tokenize string: {e}")
1850
+ return self._simple_string_hash(prompt)
1851
+ else:
1852
+ # Fall back to simple string hashing
1853
+ return self._simple_string_hash(prompt)
1854
+ else:
1855
+ raise ValueError(f"Unsupported prompt type: {type(prompt)}")
1856
+
1857
+ def _simple_string_hash(self, text: str) -> list[int]:
1858
+ """Create pseudo-tokens from string for prefix routing.
1859
+
1860
+ This is a fallback when proper tokenization isn't available.
1861
+ """
1862
+ # Use words as pseudo-tokens, limited to prefix_length
1863
+ words = text.strip().split()[: self.prefix_length]
1864
+ if not words:
1865
+ raise ValueError("No words found in text")
1866
+
1867
+ # Convert words to integers using hash
1868
+ pseudo_tokens = [
1869
+ abs(hash(word)) % 50000 for word in words
1870
+ ] # Simulate vocab size
1871
+ return pseudo_tokens
1872
+
1873
+ def _is_actor_overloaded(self, actor_index: int) -> bool:
1874
+ """Check if an actor is overloaded compared to average load.
1875
+
1876
+ Args:
1877
+ actor_index: Index of actor to check.
1878
+
1879
+ Returns:
1880
+ bool: True if actor is overloaded.
1881
+ """
1882
+ try:
1883
+ if self.async_vllm is not None:
1884
+ request_counts = self.async_vllm.get_num_unfinished_requests()
1885
+ else:
1886
+ futures = [
1887
+ actor.get_num_unfinished_requests.remote() for actor in self.actors
1888
+ ]
1889
+ ray = _get_ray()
1890
+ request_counts = ray.get(futures)
1891
+
1892
+ if not request_counts:
1893
+ return False
1894
+
1895
+ avg_requests = sum(request_counts) / len(request_counts)
1896
+ actor_requests = request_counts[actor_index]
1897
+
1898
+ is_overloaded = actor_requests > avg_requests * self.overload_threshold
1899
+
1900
+ torchrl_logger.debug(
1901
+ f"Actor {actor_index}: {actor_requests} requests, "
1902
+ f"avg: {avg_requests:.1f}, threshold: {avg_requests * self.overload_threshold:.1f}, "
1903
+ f"overloaded: {is_overloaded}"
1904
+ )
1905
+
1906
+ return is_overloaded
1907
+
1908
+ except Exception as e:
1909
+ torchrl_logger.warning(f"Could not check actor load: {e}")
1910
+ return False # Assume not overloaded if we can't check
1911
+
1912
+ def get_stats(self) -> dict[str, Any]:
1913
+ """Get current load balancing statistics for all actors.
1914
+
1915
+ Returns:
1916
+ dict: Statistics including request counts and cache usage for all actors.
1917
+ """
1918
+ stats = {
1919
+ "strategies": self.strategies,
1920
+ "primary_strategy": self.strategy, # For backward compatibility
1921
+ "num_actors": len(self.actors),
1922
+ "prefix_length": self.prefix_length,
1923
+ "overload_threshold": self.overload_threshold,
1924
+ "round_robin_index": self._round_robin_index,
1925
+ "actor_stats": [],
1926
+ }
1927
+
1928
+ try:
1929
+ if self.async_vllm is not None:
1930
+ request_counts = self.async_vllm.get_num_unfinished_requests()
1931
+ cache_usages = self.async_vllm.get_cache_usage()
1932
+ else:
1933
+ request_futures = [
1934
+ actor.get_num_unfinished_requests.remote() for actor in self.actors
1935
+ ]
1936
+ cache_futures = [
1937
+ actor.get_cache_usage.remote() for actor in self.actors
1938
+ ]
1939
+ ray = _get_ray()
1940
+ request_counts = ray.get(request_futures)
1941
+ cache_usages = ray.get(cache_futures)
1942
+
1943
+ for i, (requests, cache_usage) in enumerate(
1944
+ zip(request_counts, cache_usages)
1945
+ ):
1946
+ stats["actor_stats"].append(
1947
+ {
1948
+ "actor_index": i,
1949
+ "pending_requests": requests,
1950
+ "cache_usage": cache_usage,
1951
+ }
1952
+ )
1953
+
1954
+ except Exception as e:
1955
+ torchrl_logger.warning(f"Error gathering load balancer stats: {e}")
1956
+ stats["error"] = str(e)
1957
+
1958
+ return stats
1959
+
1960
+
1961
+ def make_async_vllm_engine(
1962
+ *,
1963
+ model_name: str,
1964
+ num_devices: int | None = None,
1965
+ num_replicas: int = 1,
1966
+ verbose: bool = True,
1967
+ compile: bool = True,
1968
+ enable_fp32_output: bool = False,
1969
+ tensor_parallel_size: int | None = None,
1970
+ data_parallel_size: int | None = None,
1971
+ pipeline_parallel_size: int | None = None,
1972
+ **kwargs,
1973
+ ) -> AsyncVLLM:
1974
+ """Create an async vLLM engine service.
1975
+
1976
+ Keyword Args:
1977
+ model_name (str): The model name to pass to vLLM.
1978
+ num_devices (int, optional): Number of devices to use, per replica.
1979
+ num_replicas (int): Number of engine replicas to create.
1980
+ verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True.
1981
+ compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
1982
+ enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
1983
+ This can help with numerical stability for certain models. Requires model-specific support in
1984
+ torchrl.modules.llm.backends._models.
1985
+ tensor_parallel_size (int, optional): Tensor parallel size per replica. Defaults to num_devices if provided, otherwise 1.
1986
+ data_parallel_size (int, optional): Number of data parallel groups to use. Defaults to None.
1987
+ pipeline_parallel_size (int, optional): Number of pipeline parallel groups to use. Defaults to None.
1988
+ **kwargs: Additional arguments passed to AsyncEngineArgs.
1989
+
1990
+ Returns:
1991
+ AsyncVLLM: The launched engine service.
1992
+
1993
+ Raises:
1994
+ RuntimeError: If no CUDA devices are available.
1995
+ ValueError: If invalid device configuration is provided.
1996
+
1997
+ Example:
1998
+ >>> # Create a single-GPU async engine
1999
+ >>> service = make_async_vllm_engine(model_name="Qwen/Qwen2.5-3B")
2000
+ >>>
2001
+ >>> # Create a 2-GPU tensor parallel async engine with 2 replicas
2002
+ >>> service = make_async_vllm_engine(model_name="Qwen/Qwen2.5-3B", num_devices=2, num_replicas=2)
2003
+ >>> # Generate text
2004
+ >>> result = service.generate("Hello, world!", sampling_params)
2005
+ >>>
2006
+ >>> # Create with FP32 output enabled
2007
+ >>> service = make_async_vllm_engine(model_name="Qwen/Qwen2.5-3B", enable_fp32_output=True)
2008
+ """
2009
+ if not _has_vllm:
2010
+ raise ImportError(
2011
+ "vllm is not installed. Please install it with `pip install vllm`."
2012
+ )
2013
+
2014
+ from vllm import AsyncEngineArgs
2015
+
2016
+ # Set FP32 output environment variable if requested
2017
+ if enable_fp32_output:
2018
+ os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
2019
+ torchrl_logger.info(
2020
+ "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
2021
+ "This will use FP32 for the final output layer if the model supports it."
2022
+ )
2023
+
2024
+ # Configure verbose logging if requested
2025
+ if verbose:
2026
+ import logging
2027
+
2028
+ # Enable vLLM's throughput logging by setting the appropriate log level
2029
+ logging.getLogger("vllm.engine.metrics").setLevel(logging.INFO)
2030
+ logging.getLogger("vllm").setLevel(logging.INFO)
2031
+
2032
+ # vLLM logs throughput stats at INFO level every few seconds
2033
+ # The stats include: prompt throughput, generation throughput, running/pending requests, GPU KV cache usage
2034
+ torchrl_logger.info(
2035
+ "Enabled verbose vLLM logging - throughput statistics will be displayed"
2036
+ )
2037
+
2038
+ # Set tensor_parallel_size to num_devices if not set
2039
+ if tensor_parallel_size is None:
2040
+ if num_devices is None:
2041
+ tensor_parallel_size = 1
2042
+ else:
2043
+ tensor_parallel_size = num_devices
2044
+ elif num_devices is not None and tensor_parallel_size != num_devices:
2045
+ raise ValueError(f"tensor_parallel_size must be set to {num_devices}")
2046
+
2047
+ if data_parallel_size is None:
2048
+ data_parallel_size = 1
2049
+
2050
+ if pipeline_parallel_size is None:
2051
+ pipeline_parallel_size = 1
2052
+
2053
+ # Create engine args
2054
+ kwargs.setdefault("distributed_executor_backend", "ray")
2055
+ # Enable prefix caching by default; callers can still override it via kwargs
2056
+ kwargs.setdefault("enable_prefix_caching", True)
2057
+
2058
+ # Set compilation flag - this controls whether vLLM will compile the model for better performance
2059
+ # GRPO configs typically disable it, since compilation can cause issues during training
2060
+ if "compilation_config" not in kwargs:
2061
+ if compile:
2062
+ kwargs["compilation_config"] = {"level": 3} # PIECEWISE compilation
2063
+ else:
2064
+ kwargs["compilation_config"] = {"level": 0} # NO_COMPILATION
2065
+
2066
+ engine_args = AsyncEngineArgs(
2067
+ model=model_name,
2068
+ tensor_parallel_size=tensor_parallel_size,
2069
+ data_parallel_size=data_parallel_size,
2070
+ pipeline_parallel_size=pipeline_parallel_size,
2071
+ worker_extension_cls="torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker",
2072
+ **kwargs,
2073
+ )
2074
+
2075
+ return AsyncVLLM.launch(engine_args, num_replicas)
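# End-to-end sketch tying the pieces above together (assumes Ray, vLLM and
# CUDA are available; import path and model name are illustrative).
from vllm import SamplingParams

from torchrl.modules.llm.backends.vllm.vllm_async import make_async_vllm_engine

service = make_async_vllm_engine(
    model_name="Qwen/Qwen2.5-3B",
    num_devices=1,
    num_replicas=2,
    compile=False,  # skip model compilation for faster startup
)
service.create_load_balancer(["prefix-aware", "requests"])

out = service.generate("Write a haiku about GPUs.", SamplingParams(max_tokens=48))
print(out.outputs[0].text)

service.shutdown()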