torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/llm/policies/vllm_wrapper.py
@@ -0,0 +1,2241 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+ 
+ import collections
+ 
+ import importlib.util
+ import threading
+ import warnings
+ from typing import Any, Literal, TYPE_CHECKING
+ 
+ import torch
+ from tensordict import (
+     lazy_stack,
+     LazyStackedTensorDict,
+     MetaData,
+     NonTensorStack,
+     set_list_to_stack,
+     TensorDict,
+     TensorDictBase,
+ )
+ from tensordict.tensorclass import from_dataclass, TensorClass
+ from tensordict.utils import _zip_strict, NestedKey
+ from torch import distributions as D
+ from torch.nn.utils.rnn import pad_sequence
+ 
+ from torchrl.envs.utils import _classproperty
+ from torchrl.modules.llm.policies.common import (
+     _batching,
+     _extract_responses_from_full_histories,
+     ChatHistory,
+     LLMWrapperBase,
+     LogProbs,
+     Masks,
+     Text,
+     Tokens,
+ )
+ from torchrl.modules.utils.utils import _unpad_tensors
+ 
+ 
+ _HAS_VLLM = importlib.util.find_spec("vllm") is not None
+ _HAS_TRANSFORMERS = importlib.util.find_spec("transformers") is not None
+ 
+ if TYPE_CHECKING:
+     from vllm.inputs import TokensPrompt  # type: ignore[import-not-found]
+     from vllm.outputs import RequestOutput  # type: ignore[import-not-found]
+     from vllm.sampling_params import SamplingParams  # type: ignore[import-not-found]
+ elif _HAS_VLLM:
+     from vllm.outputs import RequestOutput
+     from vllm.sampling_params import SamplingParams
+ 
+     try:
+         from vllm.inputs import TokensPrompt
+     except ImportError:
+         # Fallback for older vLLM versions
+         TokensPrompt = None
+ else:
+     SamplingParams = None  # Will error at usage if vLLM not available
+     RequestOutput = None
+     TokensPrompt = None
+ 
+ 
+ def _require_transformers() -> None:
+     if not _HAS_TRANSFORMERS:
+         raise ImportError(
+             "transformers is required for vLLMWrapper. Please install it with `pip install transformers`."
+         )
+ 
+ 
+ def _require_vllm():
+     """Import vLLM lazily.
+ 
+     We intentionally avoid importing vLLM at module import time because importing vLLM can
+     load native extensions that may hard-crash the interpreter on some platforms.
+     """
+     if not _HAS_VLLM:
+         raise ImportError(
+             "vllm is required for vLLMWrapper. Please install it with `pip install vllm`."
+         )
+     import vllm as _vllm  # local import is intentional / required
+ 
+     return _vllm
+ 
+ 
+ # Import async vLLM engines
+ 
+ 
+ class vLLMWrapper(LLMWrapperBase):
+     """A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation.
+ 
+     This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input
+     modalities (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
+ 
+     The wrapper supports both synchronous (vllm.LLM) and asynchronous (:class:`~torchrl.modules.llm.backends.AsyncVLLM`) vLLM engines.
+ 
+     .. note::
+         **Recommended: Use AsyncVLLM for better performance**
+ 
+         For distributed inference and better resource utilization, we recommend using
+         :class:`~torchrl.modules.llm.backends.AsyncVLLM` instead of the synchronous vllm.LLM:
+ 
+         >>> from torchrl.modules.llm.backends import AsyncVLLM
+         >>> from torchrl.modules.llm import vLLMWrapper
+         >>>
+         >>> # Recommended approach
+         >>> async_engine = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=2)
+         >>> wrapper = vLLMWrapper(async_engine, input_mode="history", generate=True)
+ 
+         AsyncVLLM provides:
+ 
+         - Better GPU utilization through Ray-based distribution
+         - Multiple replicas for higher throughput
+         - Native vLLM batching for optimal performance
+         - Automatic resource management and cleanup
+ 
+     Args:
+         model (vllm.LLM | AsyncVLLM | Ray Actor | str): The vLLM model to wrap.
+ 
+             - If a string, it will be converted to an AsyncVLLM instance (recommended)
+             - If a vllm.LLM instance, uses synchronous generation via `model.generate()`
+             - If an AsyncVLLM instance, uses async generation via `model.generate()`
+             - If a Ray actor with a generate method, uses remote calls via `ray.get(model.generate.remote())`
+ 
+     Keyword Args:
+         tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for encoding and decoding text.
+             If `None`, the tokenizer associated with the model will be used. If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`.
+             Defaults to `None`.
+         input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`. Defaults to `"history"`.
+         input_key (str | None, optional): The key for the input data. If `None`, defaults to
+ 
+             - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
+             - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
+             - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
+ 
+         attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.
+ 
+             .. warning:: This argument is under development and may change in the future.
+ 
+         generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
+             If `False`, only log probabilities will be computed. Defaults to `True`.
+         return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`.
+         generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.
+ 
+             **Standardized Parameters (cross-backend compatible):**
+ 
+             * **max_new_tokens** (int): Maximum number of new tokens to generate (maps to vLLM's max_tokens)
+             * **num_return_sequences** (int): Number of sequences to return (maps to vLLM's n)
+             * **temperature** (float): Sampling temperature (0.0 = deterministic, higher = more random)
+             * **top_p** (float): Nucleus sampling parameter (0.0-1.0)
+             * **top_k** (int): Top-k sampling parameter
+             * **repetition_penalty** (float): Penalty for repeating tokens
+             * **do_sample** (bool): Whether to use sampling vs greedy decoding
+             * **num_beams** (int): Number of beams for beam search
+             * **length_penalty** (float): Penalty for sequence length
+             * **early_stopping** (bool): Whether to stop early in beam search
+             * **stop_sequences** (list): Sequences that stop generation (maps to vLLM's stop)
+             * **skip_special_tokens** (bool): Whether to skip special tokens in output
+             * **logprobs** (bool): Whether to return log probabilities
+ 
+               .. warning:: Usage of this parameter is discouraged as it may conflict with the `generate` parameter
+                   of the class.
+ 
+             **vLLM-Specific Parameters:**
+ 
+             * **presence_penalty** (float): Penalty for token presence
+             * **frequency_penalty** (float): Penalty for token frequency
+             * **ignore_eos** (bool): Whether to ignore EOS token
+             * **prompt_logprobs** (bool): Whether to return prompt log probabilities
+             * **detokenize** (bool): Whether to detokenize output
+             * **include_stop_str_in_output** (bool): Whether to include stop strings in output
+             * **spaces_between_special_tokens** (bool): Whether to add spaces between special tokens
+             * **sampling_type** (str): Type of sampling to use
+             * **temperature_last** (bool): Whether to apply temperature only to the last token
+             * **top_p_last** (bool): Whether to apply top_p only to the last token
+             * **top_k_last** (bool): Whether to apply top_k only to the last token
+ 
+             **Legacy Parameter Support:**
+ 
+             * **max_tokens** (int): Automatically converted to max_new_tokens
+             * **n** (int): Automatically converted to num_return_sequences
+ 
+             **Parameter Conflict Resolution:**
+ 
+             When both legacy (vLLM-specific) and standardized parameter names are provided,
+             a :exc:`ValueError` is raised to prevent confusion. For example:
+ 
+             * If both ``max_tokens`` and ``max_new_tokens`` are passed, an error is raised
+             * If both ``n`` and ``num_return_sequences`` are passed, an error is raised
+ 
+             This ensures clear parameter usage and prevents unexpected behavior, as in the sketch below.
+ 
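+             A minimal sketch (assuming a constructed ``model``; illustrative only):
+ 
+                 >>> # both the legacy name (max_tokens) and the standardized name
+                 >>> # (max_new_tokens) are given below, so construction raises a ValueError
+                 >>> vLLMWrapper(model, generate_kwargs={"max_tokens": 50, "max_new_tokens": 50})
+ 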
+         tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
+         pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Defaults to `False`.
+         pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length.
+             Not supported by vLLM; passing anything other than `None` raises an error.
+         inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
+         device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
+         layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
+         chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat template to the history.
+             Defaults to `None`. For `input_mode="history"` only.
+         chat_template (str | None, optional): The chat template to use when applying the chat template to the history. Defaults to `None`.
+             For `input_mode="history"` only.
+         num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (a single sample, with no extra batch dimension for it).
+             Can also be set via the `generate_kwargs["n"] = value` argument.
+         log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+         text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+         tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+         masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+         history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
+         batching (bool, optional): Whether to enable batching. Defaults to `False`. See `Batching`_ below for more details.
+         min_batch_size (int | None, optional): The minimum batch size to use for batching. See `Batching`_ below for more details.
+         max_batch_size (int | None, optional): The maximum batch size to use for batching. See `Batching`_ below for more details.
+         batching_timeout (float, optional): The timeout for batching. See `Batching`_ below for more details.
+ 
+     .. _Batching:
+ 
+     **Batching**
+ 
+     Batching allows the module to process multiple inputs in a single call and is designed to
+     work in a multi-threaded environment. To enable it, setting `batching=True` suffices; this
+     sets `min_batch_size` to 1 if it is not provided. For finer-grained control, set
+     `batching=True` and then set `min_batch_size` or `max_batch_size` to a value greater than
+     or equal to 1. Batching works as follows:
+ 
+     - If `min_batch_size` is not provided but `max_batch_size` is, `min_batch_size` is set to 1.
+     - If `max_batch_size` is not provided but `min_batch_size` is, `max_batch_size` is set to the number of inputs in the queue.
+     - When the model is called, a check is performed to see whether the number of inputs in the queue is greater than or equal to `min_batch_size`.
+       If it is, the batch is processed immediately (after waiting for the previous batch to finish if the model is busy).
+       Otherwise, the input is added to the queue and the function waits for the batch to be completed.
+       While waiting, a timeout of `batching_timeout` seconds applies: if the batch is not complete by then,
+       the remaining items are processed as-is, and the function returns after at most `batching_timeout`
+       seconds (plus the time needed to finish processing the previous and current batches).
+ 
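+     A minimal usage sketch (assumed model name; illustrative only):
+ 
+         >>> wrapper = vLLMWrapper(
+         ...     "Qwen/Qwen2.5-3B",
+         ...     batching=True,         # min_batch_size defaults to 1
+         ...     max_batch_size=8,      # group up to 8 queued requests per engine call
+         ...     batching_timeout=5.0,  # wait at most 5 seconds for a batch to fill up
+         ... )
+         >>> # concurrent calls from several threads are then grouped transparently
+ 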
+     Input Keys:
+         The input key depends on both `input_mode` and `generate`:
+ 
+         - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
+         - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
+         - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
+         - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
+         - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
+         - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)
+ 
+     Output Keys:
+         The output keys are automatically determined based on the input_mode:
+ 
+         - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
+         - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
+         - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
+         - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
+         - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)
+ 
+     Example output structure for `input_mode="history"`::
+ 
+         TensorDict(
+             text=Text(prompt=..., response=..., full=...),
+             masks=Masks(all_attention_mask=..., all_assistant_mask=...),
+             tokens=Tokens(prompt=..., response=..., full=...),
+             log_probs=LogProbs(prompt=..., response=..., full=...),
+             history=ChatHistory(prompt=..., response=..., full=...)
+         )
+ 
+     Example:
+         >>> from tensordict import TensorDict
+         >>> from vllm import LLM
+         >>> from transformers import AutoTokenizer
+         >>> from torchrl.data.llm import History
+         >>> from torchrl.modules.llm.policies import ChatHistory
+         >>>
+         >>> model = LLM("gpt2")
+         >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         >>>
+         >>> # History input (recommended for RL environments)
+         >>> wrapper = vLLMWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     input_mode="history",
+         ...     generate=True,
+         ...     return_log_probs=True,
+         ...     generate_kwargs={
+         ...         "max_new_tokens": 50,  # Standardized parameter
+         ...         "temperature": 0.7,
+         ...         "top_p": 0.9,
+         ...         "do_sample": True,
+         ...     }
+         ... )
+         >>>
+         >>> history = History.from_chats([[
+         ...     {"role": "user", "content": "Hello"},
+         ...     {"role": "assistant", "content": "Hi there!"}
+         ... ]])
+         >>> chat_history = ChatHistory(prompt=history)
+         >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+         >>> print(result["text"].response)  # Generated text
+         >>> print(result["log_probs"].response)  # Log probabilities
+         >>> print(result["history"].response)  # History with response
+ 
+     Attributes:
+         collector: The collector associated with the module, if it exists.
+ 
+     .. seealso::
+         - :class:`~torchrl.modules.llm.policies.LLMWrapperBase`
+         - :class:`~torchrl.modules.llm.policies.TransformersWrapper`
+     """
+ 
+     def __init__(
+         self,
+         model: Any,  # vllm.LLM | AsyncVLLMEngineService | AsyncLLMEngineExtended | str
+         *,
+         tokenizer: callable | str | None = None,  # type: ignore
+         input_mode: str = "history",
+         input_key: NestedKey | None = None,
+         attention_mask_key: str = "attention_mask",
+         generate: bool = True,
+         generate_kwargs: dict | None = None,
+         tokenizer_kwargs: dict | None = None,
+         pad_output: bool = False,
+         pad_model_input: bool | None = None,
+         inplace: Literal[True, False, "empty"] | None = None,
+         device: torch.device | None = None,
+         layout: torch.layout | None = None,
+         num_samples: int | None = None,
+         chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+         chat_template: str | None = None,
+         return_log_probs: bool | None = None,
+         history_key: NestedKey | None = "history",
+         text_key: NestedKey | None = "text",
+         tokens_key: NestedKey | None = "tokens",
+         masks_key: NestedKey | None = "masks",
+         log_probs_key: NestedKey | None = "log_probs",
+         batching: bool | None = None,
+         min_batch_size: int | None = None,
+         max_batch_size: int | None = None,
+         batching_timeout: float = 10.0,
+     ):
+         super().__init__()
+ 
+         if batching and min_batch_size is None:
+             min_batch_size = 1
+         elif (min_batch_size is not None or max_batch_size is not None) and (
+             batching is False
+         ):
+             raise ValueError(
+                 "min_batch_size and max_batch_size must be None if batching is False."
+             )
+ 
+         # Validate that min_batch_size <= max_batch_size when both are specified
+         if min_batch_size is not None and max_batch_size is not None:
+             if min_batch_size > max_batch_size:
+                 raise ValueError(
+                     f"min_batch_size ({min_batch_size}) must be <= max_batch_size ({max_batch_size})"
+                 )
+ 
+         self._min_batch_size = min_batch_size
+         self._max_batch_size = max_batch_size
+         self._batching_timeout = batching_timeout
+         self._batch_queue = []
+         self._futures = []
+         if self.batching:
+             self._batching_lock = threading.Lock()
+         else:
+             self._batching_lock = None
+ 
+         _require_transformers()
+ 
+         # Detect and initialize model
+         if isinstance(model, str):
+             # Import lazily to avoid importing vLLM backends unless actually needed.
+             from torchrl.modules.llm.backends.vllm import (  # local import is intentional / required
+                 AsyncVLLM,
+             )
+ 
+             model = AsyncVLLM.from_pretrained(model)
+ 
+         # Validate model type
+         model_type = type(model)
+         model_module = getattr(model_type, "__module__", "")
+         model_name = getattr(model_type, "__name__", "")
+         if model_name == "AsyncVLLM" and model_module.startswith(
+             "torchrl.modules.llm.backends.vllm"
+         ):
+             self._model_type = "async_vllm"
+         elif model_name == "LLM" and model_module.startswith("vllm"):
+             self._model_type = "sync_vllm"
+         elif hasattr(model, "generate") and hasattr(model, "remote"):
+             # Ray actor with generate method
+             self._model_type = "ray_actor"
+         else:
+             raise ValueError(
+                 f"model must be a string, vllm.LLM, AsyncVLLM, or Ray actor. Got {type(model)}"
+             )
+ 
+         if isinstance(tokenizer, str):
+             from transformers import AutoTokenizer
+ 
+             tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+         # Import vLLM lazily: only needed if we are going to interact with vLLM types.
+         # (This keeps importing this module safe even if vLLM hard-crashes on import.)
+         if self._model_type in ("sync_vllm",):
+             _require_vllm()
+ 
+         # Validate input_mode
+         if input_mode not in ["history", "text", "tokens"]:
+             raise ValueError(
+                 f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
+             )
+ 
+         self.model = model
+         self.input_mode = input_mode
+         self.attention_mask_key = attention_mask_key
+         self.generate = generate
+         if pad_model_input is not None:
+             raise ValueError("pad_model_input is not supported by vLLMWrapper.")
+ 
+         # Auto-determine what to return based on input mode
+         self.return_history = input_mode in ("history",)
+         self.return_text = input_mode in ("text", "history")
+         self.return_tokens = input_mode in ("tokens", "history", "text")
+         self.return_masks = True
+         if return_log_probs is False and not generate:
+             raise ValueError("return_log_probs must be True when generate=False.")
+         return_log_probs = (
+             True
+             if (return_log_probs is None and generate) or (not generate)
+             else bool(return_log_probs)
+         )
+         self.return_log_probs = return_log_probs
+ 
+         self.history_key = history_key
+         self.log_probs_key = log_probs_key
+         self.masks_key = masks_key
+         self.text_key = text_key
+         self.tokens_key = tokens_key
+ 
+         if not isinstance(pad_output, bool):
+             raise ValueError("pad_output must be a boolean")
+         self.pad_output = pad_output
+         self._device = device
+         if not pad_output and layout is None:
+             layout = torch.strided
+         self.layout = layout
+         padding_value = None
+ 
+         # Set input keys based on mode and generate parameter
+         if input_mode == "history":
+             if generate:
+                 self.in_keys = [
+                     ("history", "prompt") if input_key is None else input_key
+                 ]
+             else:
+                 self.in_keys = [("history", "full") if input_key is None else input_key]
+         elif input_mode == "text":
+             if generate:
+                 self.in_keys = [("text", "prompt") if input_key is None else input_key]
+             else:
+                 self.in_keys = [("text", "full") if input_key is None else input_key]
+         elif input_mode == "tokens":
+             if generate:
+                 self.in_keys = [
+                     ("tokens", "prompt") if input_key is None else input_key
+                 ]
+             else:
+                 self.in_keys = [("tokens", "full") if input_key is None else input_key]
+         else:
+             raise ValueError(f"Invalid input_mode: {input_mode}")
+         self.input_key = self.in_keys[0]
+ 
+         # Set output keys based on auto-determined return flags
+         self.out_keys = []
+         if self.return_text:
+             self.out_keys.append(self.text_key)
+         if self.return_masks:
+             self.out_keys.append(self.masks_key)
+         if self.return_tokens:
+             self.out_keys.append(self.tokens_key)
+         if self.return_log_probs:
+             self.out_keys.append(self.log_probs_key)
+         if self.return_history:
+             self.out_keys.append(self.history_key)
+ 
+         # Tokenizer setup
+         if not tokenizer_kwargs:
+             tokenizer_kwargs = {}
+         if not tokenizer_kwargs.setdefault("return_attention_mask", True):
+             raise RuntimeError("return_attention_mask must be True")
+ 
+         # If we don't pad, we use lists
+         return_tensors = "pt" if self.pad_output else False
+         if return_tensors:
+             if (
+                 tokenizer_kwargs.setdefault("return_tensors", return_tensors)
+                 != return_tensors
+             ):
+                 raise RuntimeError
+         if tokenizer_kwargs.setdefault("padding", self.pad_output) not in (
+             self.pad_output,
+         ):
+             raise RuntimeError
+         if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
+             raise RuntimeError
+ 
+         self.tokenizer_kwargs = tokenizer_kwargs
+ 
+         # Get tokenizer if needed
+         if tokenizer is None:
+             try:
+                 if hasattr(model, "get_tokenizer"):
+                     tokenizer = model.get_tokenizer()
+                 else:
+                     # Try to extract model name and load tokenizer as fallback
+                     model_name = self._extract_model_name(model)
+                     if model_name:
+                         warnings.warn(
+                             f"No tokenizer provided. Attempting to load tokenizer from model name: {model_name}"
+                         )
+                         from transformers import AutoTokenizer
+ 
+                         try:
+                             tokenizer = AutoTokenizer.from_pretrained(model_name)
+                         except Exception as tokenizer_error:
+                             warnings.warn(
+                                 f"Failed to load tokenizer from {model_name}: {tokenizer_error}"
+                             )
+                     else:
+                         warnings.warn(
+                             "No tokenizer provided and no tokenizer found in model."
+                         )
+             except Exception as e:
+                 warnings.warn(f"Could not get tokenizer from model: {e}")
+         self.tokenizer = tokenizer
+ 
+         if self.tokenizer is not None and (
+             not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
+         ):
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         if self.tokenizer is not None:
+             padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
+         self.padding_value = padding_value
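+         # Illustration (assumed tokenizer): with GPT-2, pad_token falls back to the EOS token
+         # "<|endoftext|>", so padding_value resolves to its token id (50256).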
+ 
+         # Generate kwargs setup
+         if generate_kwargs is None:
+             generate_kwargs = {}
+         else:
+             generate_kwargs = dict(generate_kwargs)
+ 
+         # Standardize common parameters
+         generate_kwargs = self._standardize_generate_kwargs(generate_kwargs)
+ 
+         # Extract wrapper-specific parameters
+         vllm_specific_kwargs = self._get_wrapper_specific_kwargs(
+             generate_kwargs, "vllm"
+         )
+ 
+         # Convert common parameters back to vLLM format
+         vllm_kwargs = {}
+         for key, value in generate_kwargs.items():
+             if key in self.COMMON_GENERATION_PARAMS:
+                 # Convert common names to vLLM names
+                 if key == "max_new_tokens":
+                     vllm_kwargs["max_tokens"] = value
+                 elif key == "num_return_sequences":
+                     vllm_kwargs["n"] = value
+                 elif key == "stop_sequences":
+                     vllm_kwargs["stop"] = value
+                 elif key == "logprobs":
+                     # vLLM expects int for logprobs, not bool
+                     if isinstance(value, bool):
+                         value = 1 if value else None
+                     vllm_kwargs["logprobs"] = value
+                 elif key == "do_sample":
+                     # do_sample is handled through the sampling parameters
+                     # If do_sample=False, we use greedy decoding (temperature=0)
+                     # If do_sample=True, we use the provided sampling parameters
+                     if not value:
+                         vllm_kwargs["temperature"] = 0.0
+                     # If do_sample=True, we keep the existing temperature/top_p/top_k values
+                 elif key in ["length_penalty", "early_stopping", "num_beams"]:
+                     # These parameters are not supported by vLLM, skip them
+                     pass
+                 else:
+                     # Direct mapping for other common parameters
+                     vllm_kwargs[key] = value
+ 
+         # Add vLLM-specific parameters
+         vllm_kwargs.update(vllm_specific_kwargs)
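+         # Illustrative mapping (assumed values): generate_kwargs={"max_new_tokens": 50, "stop_sequences": ["\n"]}
+         # ends up as vllm_kwargs={"max_tokens": 50, "stop": ["\n"]} before SamplingParams is built.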
+ 
+         self.num_samples = num_samples
+         if vllm_kwargs.get("n", 1) > 1 or num_samples is not None:
+             if inplace in (True, "empty"):
+                 raise ValueError(
+                     "inplace must be False (or None) when generating more than one sample."
+                 )
+             if inplace is None:
+                 inplace = False
+             if (
+                 vllm_kwargs.get("n", 1) > 1
+                 and num_samples is not None
+                 and vllm_kwargs.get("n", 1) != num_samples
+             ):
+                 raise ValueError("num_samples differs from generate_kwargs['n'].")
+             elif num_samples is None:
+                 self.num_samples = vllm_kwargs.get("n", 1)
+             vllm_kwargs["n"] = self.num_samples
+         elif inplace is None:
+             inplace = True
+ 
+         self.inplace = inplace
+ 
+         # vLLM expects int for logprobs, not bool. Use 1 if True, None if False.
+         prompt_logprobs = 1 if return_log_probs else None
+ 
+         if not generate:
+             # We want only the log-probs: we generate a single token (that we then discard)
+             # and retrieve the prompt log-probs
+             vllm_kwargs["max_tokens"] = 1
+             if not return_log_probs:
+                 raise ValueError("return_log_probs must be True when generate=False.")
+ 
+         vllm_kwargs.setdefault("detokenize", not pad_output)
+         vllm_kwargs.setdefault("prompt_logprobs", prompt_logprobs)
+         vllm_kwargs.setdefault("logprobs", 1 if return_log_probs else None)
+         vllm_kwargs.setdefault("include_stop_str_in_output", True)
+         vllm_kwargs.setdefault("skip_special_tokens", False)
+ 
+         sampling_params = SamplingParams(**vllm_kwargs)
+         self.sampling_params = sampling_params
+ 
+         # Additional chat-template settings
+         self.chat_template_name = chat_template_name
+         self.chat_template = chat_template
+ 
+     def get_new_version(self, **kwargs):
+         """Returns a new version of the module with altered parameters.
+ 
+         For instance, the generate parameter can be altered to switch between text generation
+         and log-probabilities computation. This is especially useful to avoid re-initializing
+         the module from scratch when the same engine and parameters can be reused to gather log-probs.
+ 
+         Positional arguments are not supported.
+ 
+         See the class constructor for more details about the parameters.
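+ 
+         Example (a sketch; assumes an existing ``wrapper`` built with ``generate=True``):
+             >>> lp_wrapper = wrapper.get_new_version(generate=False, return_log_probs=True)
+             >>> # Same engine and tokenizer; the new module only computes log-probs.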
+         """
+         # Build the constructor arguments by using current values for missing parameters
+         constructor_kwargs = {}
+ 
+         # Model is always required
+         constructor_kwargs["model"] = kwargs.get("model", self.model)
+ 
+         # Check for each parameter and use current value if not provided
+         if "tokenizer" in kwargs:
+             constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
+         elif hasattr(self, "tokenizer"):
+             constructor_kwargs["tokenizer"] = self.tokenizer
+ 
+         if "input_mode" in kwargs:
+             constructor_kwargs["input_mode"] = kwargs["input_mode"]
+         elif hasattr(self, "input_mode"):
+             constructor_kwargs["input_mode"] = self.input_mode
+ 
+         if "input_key" in kwargs:
+             constructor_kwargs["input_key"] = kwargs["input_key"]
+         # Since the input_key is dynamically determined, we don't want to set it here
+         # elif hasattr(self, "input_key"):
+         #     constructor_kwargs["input_key"] = self.input_key
+ 
+         if "attention_mask_key" in kwargs:
+             constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
+         elif hasattr(self, "attention_mask_key"):
+             constructor_kwargs["attention_mask_key"] = self.attention_mask_key
+ 
+         if "generate" in kwargs:
+             constructor_kwargs["generate"] = kwargs["generate"]
+         elif hasattr(self, "generate"):
+             constructor_kwargs["generate"] = self.generate
+ 
+         if "return_log_probs" in kwargs:
+             constructor_kwargs["return_log_probs"] = kwargs["return_log_probs"]
+         elif not constructor_kwargs.get("generate", True):
+             # if we are not generating, we want to return log-probs
+             constructor_kwargs["return_log_probs"] = True
+         elif hasattr(self, "return_log_probs"):
+             constructor_kwargs["return_log_probs"] = self.return_log_probs
+ 
+         if "generate_kwargs" in kwargs:
+             constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
+         elif hasattr(self, "generate_kwargs"):
+             constructor_kwargs["generate_kwargs"] = self.generate_kwargs
+ 
+         if "pad_output" in kwargs:
+             constructor_kwargs["pad_output"] = kwargs["pad_output"]
+         elif hasattr(self, "pad_output"):
+             constructor_kwargs["pad_output"] = self.pad_output
+ 
+         if "tokenizer_kwargs" in kwargs:
+             constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
+         elif hasattr(self, "tokenizer_kwargs"):
+             constructor_kwargs["tokenizer_kwargs"] = dict(self.tokenizer_kwargs)
+             if (
+                 "pad_output" in kwargs
+                 and kwargs.get("pad_output")
+                 != constructor_kwargs["tokenizer_kwargs"]["padding"]
+             ):
+                 constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
+                     "pad_output"
+                 )
+ 
+         if "inplace" in kwargs:
+             constructor_kwargs["inplace"] = kwargs["inplace"]
+         elif hasattr(self, "inplace"):
+             constructor_kwargs["inplace"] = self.inplace
+ 
+         if "device" in kwargs:
+             constructor_kwargs["device"] = kwargs["device"]
+         elif hasattr(self, "_device"):
+             constructor_kwargs["device"] = self._device
+ 
+         if "layout" in kwargs:
+             constructor_kwargs["layout"] = kwargs["layout"]
+         elif hasattr(self, "layout"):
+             constructor_kwargs["layout"] = self.layout
+ 
+         if "num_samples" in kwargs:
+             constructor_kwargs["num_samples"] = kwargs["num_samples"]
+         elif hasattr(self, "num_samples"):
+             constructor_kwargs["num_samples"] = self.num_samples
+ 
+         if "chat_template_name" in kwargs:
+             constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
+         elif hasattr(self, "chat_template_name"):
+             constructor_kwargs["chat_template_name"] = self.chat_template_name
+ 
+         if "chat_template" in kwargs:
+             constructor_kwargs["chat_template"] = kwargs["chat_template"]
+         elif hasattr(self, "chat_template"):
+             constructor_kwargs["chat_template"] = self.chat_template
+ 
+         if "history_key" in kwargs:
+             constructor_kwargs["history_key"] = kwargs["history_key"]
+         elif hasattr(self, "history_key"):
+             constructor_kwargs["history_key"] = self.history_key
+ 
+         if "text_key" in kwargs:
+             constructor_kwargs["text_key"] = kwargs["text_key"]
+         elif hasattr(self, "text_key"):
+             constructor_kwargs["text_key"] = self.text_key
+ 
+         if "tokens_key" in kwargs:
+             constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
+         elif hasattr(self, "tokens_key"):
+             constructor_kwargs["tokens_key"] = self.tokens_key
+ 
+         if "masks_key" in kwargs:
+             constructor_kwargs["masks_key"] = kwargs["masks_key"]
+         elif hasattr(self, "masks_key"):
+             constructor_kwargs["masks_key"] = self.masks_key
+ 
+         if "log_probs_key" in kwargs:
+             constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
+         elif hasattr(self, "log_probs_key"):
+             constructor_kwargs["log_probs_key"] = self.log_probs_key
+ 
+         # Create and return new instance
+         return type(self)(**constructor_kwargs)
+ 
+     def set_tokenizer(self, tokenizer):
+         """Set the tokenizer for the wrapper. Useful for async engines where tokenizer retrieval is deferred."""
+         self.tokenizer = tokenizer
+         if self.tokenizer is not None and (
+             not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
+         ):
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         if self.tokenizer is not None:
+             padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
+         else:
+             padding_value = None
+         self.padding_value = padding_value
+ 
+     def _extract_model_name(self, model) -> str | None:
+         """Extract the model name from different model types for the tokenizer fallback."""
+         try:
+             # For AsyncVLLM, try to get the model name from engine_args
+             if hasattr(model, "engine_args") and hasattr(model.engine_args, "model"):
+                 return model.engine_args.model
+ 
+             # For vllm.LLM, try to get the model name
+             elif hasattr(model, "llm_engine") and hasattr(
+                 model.llm_engine, "model_config"
+             ):
+                 return getattr(model.llm_engine.model_config, "model", None)
+ 
+             # For Ray actors, try to get the model name via a remote call
+             elif hasattr(model, "remote") and hasattr(model, "get_model_name"):
+                 import ray
+ 
+                 try:
+                     return ray.get(model.get_model_name.remote())
+                 except Exception:
+                     pass
+ 
+             # Try common attributes that might contain the model name
+             for attr in ["model_name", "model", "model_path", "_model_name"]:
+                 if hasattr(model, attr):
+                     value = getattr(model, attr)
+                     if isinstance(value, str):
+                         return value
+ 
+             return None
+         except Exception:
+             return None
+ 
+     def _call_generate(self, *args, **kwargs):
+         """Call the generate method appropriate for the model type.
+ 
+         In vLLM 0.14+, prompt_token_ids should be passed as TokensPrompt objects
+         rather than as a keyword argument.
+         """
+         # Convert prompt_token_ids to TokensPrompt format for vLLM 0.14+ compatibility
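+         # Illustrative conversion (assumed ids): [[1, 2, 3], [4, 5]]
+         #   -> [TokensPrompt(prompt_token_ids=[1, 2, 3]), TokensPrompt(prompt_token_ids=[4, 5])]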
814
+ prompt_token_ids = kwargs.pop("prompt_token_ids", None)
815
+ if prompt_token_ids is not None and TokensPrompt is not None:
816
+ # Convert list of token ID lists to TokensPrompt objects
817
+ if isinstance(prompt_token_ids, list) and len(prompt_token_ids) > 0:
818
+ if isinstance(prompt_token_ids[0], list):
819
+ # List of token ID lists -> list of TokensPrompt
820
+ prompts = [
821
+ TokensPrompt(prompt_token_ids=tids) for tids in prompt_token_ids
822
+ ]
823
+ else:
824
+ # Single token ID list -> single TokensPrompt
825
+ prompts = TokensPrompt(prompt_token_ids=prompt_token_ids)
826
+ # Insert prompts as the first positional argument
827
+ args = (prompts,) + args
828
+ elif prompt_token_ids is not None:
829
+ # Fallback for older vLLM versions that still support prompt_token_ids kwarg
830
+ kwargs["prompt_token_ids"] = prompt_token_ids
831
+
832
+ if self._model_type == "ray_actor":
833
+ import ray
834
+
835
+ return ray.get(self.model.generate.remote(*args, **kwargs))
836
+ else:
837
+ # Both sync_vllm and async_vllm have direct generate methods
838
+ return self.model.generate(*args, **kwargs)
839
+
840
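+     # Sketch of the conversion performed by `_call_generate` above (token ids are
+     # made up). With vLLM 0.14+, a batch such as
+     #
+     #     prompt_token_ids = [[1, 2, 3], [4, 5]]
+     #
+     # is wrapped as
+     #
+     #     prompts = [TokensPrompt(prompt_token_ids=[1, 2, 3]),
+     #                TokensPrompt(prompt_token_ids=[4, 5])]
+     #
+     # and passed positionally to `generate`, while older vLLM versions receive the
+     # raw `prompt_token_ids` keyword instead.
+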
+     @set_list_to_stack(True)
+     @_batching
+     def forward(
+         self,
+         tensordict: TensorDictBase,
+         *,
+         tensordict_out: TensorDictBase | None = None,
+         logits_only: bool = False,
+         **kwargs,
+     ) -> TensorDictBase:
+         tensordict_orig = tensordict
+         if not tensordict.ndim:
+             if tensordict_out is not None:
+                 raise ValueError(
+                     "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                     "please submit an issue on GitHub."
+                 )
+             # unsqueeze the input, then squeeze the output
+             return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
+         elif tensordict.ndim > 1:
+             if tensordict_out is not None:
+                 raise ValueError(
+                     "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                     "please submit an issue on GitHub."
+                 )
+             return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                 tensordict.shape
+             )
+
+         if not isinstance(tensordict, LazyStackedTensorDict):
+             tensordict = tensordict.to_lazystack(0)
+
+         _source_device = None
+         if self._device:
+             _source_device = tensordict.device
+         if tensordict.device:
+             tensordict = tensordict.copy().clear_device_()
+
+         if kwargs:
+             from vllm import SamplingParams
+
+             sampling_params = SamplingParams(**kwargs)
+         else:
+             sampling_params = self.sampling_params
+
+         if self.num_samples is not None:
+             out = (
+                 TensorDict(
+                     device=tensordict.device,
+                     batch_size=(
+                         tensordict.batch_size[0],
+                         self.num_samples,
+                         *tensordict.batch_size[1:],
+                     ),
+                 )
+                 .to_lazystack(1)
+                 .to_lazystack(0)
+             )
+         else:
+             out = TensorDict(
+                 device=tensordict.device, batch_size=tensordict.batch_size
+             ).to_lazystack(0)
+
+         if self.input_mode == "history":
+             if self.generate:
+                 out = self._from_vllm_generate_history(tensordict, sampling_params, out)
+             else:
+                 out = self._from_vllm_logprobs_history(tensordict, sampling_params, out)
+         elif self.input_mode == "text":
+             if self.generate:
+                 out = self._from_vllm_generate_text(tensordict, sampling_params, out)
+             else:
+                 out = self._from_vllm_logprobs_text(tensordict, sampling_params, out)
+         elif self.input_mode == "tokens":
+             if self.generate:
+                 out = self._from_vllm_generate_tokens(tensordict, sampling_params, out)
+             else:
+                 out = self._from_vllm_logprobs_tokens(tensordict, sampling_params, out)
+
+         if _source_device:
+             out = out.to(_source_device)
+
+         if tensordict_out is None:
+             if self.inplace is True:
+                 # The output is the input
+                 tensordict_out = tensordict_orig
+             elif self.inplace is False:
+                 # The output is the new structure
+                 tensordict_out = out
+             elif self.inplace == "empty":
+                 # The output is empty
+                 tensordict_out = tensordict.empty()
+
+         if tensordict_out is not None and tensordict_out is not out:
+             result = tensordict_out.exclude(*self.out_keys, inplace=True)
+             result.update(out, keys_to_update=self.out_keys)
+         elif tensordict_out is out:
+             result = out.select(*self.out_keys)
+         elif self.inplace:
+             result = out
+             keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
+             result = tensordict.exclude(*self.out_keys, inplace=True).update(
+                 result, keys_to_update=keys
+             )
+         else:
+             result = out
+         return result
+
+     def _from_vllm_generate_history(
+         self,
+         tensordict_input: TensorDictBase,
+         sampling_params: Any,
+         out: TensorDictBase,
+     ) -> TensorDictBase:
+         """Generate text from history input."""
+         from torchrl.data.llm import History
+
+         assert isinstance(
+             tensordict_input, TensorDictBase
+         ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         # Validate input
+         if self.input_key not in tensordict_input:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for history input mode, "
+                 f"but found keys: {list(tensordict_input.keys())}"
+             )
+
+         history = tensordict_input.get(self.input_key)
+         if not isinstance(history, History):
+             raise TypeError(
+                 f"Expected History object for '{self.input_key}', got {type(history)}"
+             )
+
+         # Apply chat template
+         tokenizer_kwargs = {}
+         if self.chat_template_name is not None:
+             tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+         if self.chat_template is not None:
+             tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+         tokenizer_kwargs.setdefault("add_generation_prompt", True)
+         text_prompt = history.apply_chat_template(
+             tokenizer=self.tokenizer, **tokenizer_kwargs
+         )
+
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("return_dict", True)
+         response_struct = history.apply_chat_template(
+             tokenizer=self.tokenizer, **tokenizer_kwargs
+         )
+         tokens_prompt_padded = None
+         tokens_prompt_unpadded = None
+         if self.pad_output:
+             tokens_prompt_padded = response_struct.get(
+                 "input_ids",
+                 as_padded_tensor=True,
+                 padding_value=self.padding_value,
+                 padding_side="left",
+             )
+         else:
+             tokens_prompt_unpadded = response_struct.get("input_ids", as_list=True)
+
+         result = self._generate_from_tokens(
+             tokens_prompt_padded=tokens_prompt_padded,
+             tokens_prompt_unpadded=tokens_prompt_unpadded,
+             sampling_params=sampling_params,
+             out=out,
+         )
+
+         # Generate using text path
+         if self.pad_output:
+             result[(self.tokens_key, "prompt")] = (
+                 tokens_prompt_padded
+                 if not self.num_samples
+                 else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
+             )
+         else:
+             tokens_prompt_nested = torch.nested.as_nested_tensor(tokens_prompt_unpadded)
+             if not self.num_samples:
+                 result[(self.tokens_key, "prompt")] = tokens_prompt_nested
+             else:
+                 for r in result.unbind(1):
+                     r[(self.tokens_key, "prompt")] = tokens_prompt_nested
+
+         text_result = Text._from_tensordict(result.empty())
+         result.set(self.text_key, text_result)
+         if not self.num_samples:
+             text_result.prompt = text_prompt
+         else:
+             for r in result.unbind(1):
+                 r[self.text_key, "prompt"] = text_prompt
+         with result.view(-1) as result_flat:
+             if self.pad_output:
+                 tokens_full_padded = result_flat.get(
+                     (self.tokens_key, "full"),
+                     as_padded_tensor=True,
+                     padding_side="right",
+                     padding_value=self.padding_value,
+                 )
+                 if tokens_full_padded is None:
+                     raise ValueError("tokens_full_padded is None")
+                 text_full = self.tokenizer.batch_decode(
+                     tokens_full_padded, skip_special_tokens=False
+                 )
+             else:
+                 tokens_full_unpadded = result_flat.get(
+                     (self.tokens_key, "full"), as_list=True
+                 )
+                 if tokens_full_unpadded is None:
+                     raise ValueError("tokens_full_unpadded is None")
+                 text_full = self.tokenizer.batch_decode(
+                     tokens_full_unpadded, skip_special_tokens=False
+                 )
+             text_prompt = result_flat[self.text_key, "prompt"]
+             text_response = [
+                 txt[len(prompt) :]
+                 for txt, prompt in _zip_strict(text_full, text_prompt)
+             ]
+             result_flat.set((self.text_key, "full"), text_full)
+             result_flat.set((self.text_key, "response"), text_response)
+
+         # Now parse the full text back to a History object, and use the extra
+         # history items as the response
+         history_chat = ChatHistory._from_tensordict(result.empty())
+         if self.num_samples is None:
+             history_chat.prompt = history
+         else:
+             for h in history_chat.unbind(1):
+                 h.prompt = history
+         with history_chat.view(-1) as history_chat_flat:
+             prompt_histories = history_chat_flat.prompt
+             # Extract response histories from full text
+             h_responses = _extract_responses_from_full_histories(
+                 text_full, prompt_histories, self.chat_template_name, self.tokenizer
+             )
+             history_chat_flat.response = h_responses
+             history_chat_flat.full = history_chat_flat.prompt.extend(
+                 h_responses, inplace=False, dim=-1
+             )
+         result.set(self.history_key, history_chat)
+         return result
+
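+     # Note on the two `apply_chat_template` calls above: the first renders the
+     # history to a text prompt (with the generation prompt appended), the second
+     # re-renders it with tokenize=True/return_dict=True so that token ids are
+     # available for the token-based generation path.
+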
+     def _from_vllm_logprobs_history(
+         self,
+         tensordict_input: TensorDictBase,
+         sampling_params: Any,
+         out: TensorDictBase,
+     ) -> TensorDictBase:
+         """Compute log-probs from history input."""
+         assert isinstance(
+             tensordict_input, TensorDictBase
+         ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         from torchrl.data.llm import History
+
+         # Validate input
+         if self.input_key not in tensordict_input:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for history input mode, "
+                 f"but found keys: {list(tensordict_input.keys())}"
+             )
+
+         history = tensordict_input.get(self.input_key)
+         if not isinstance(history, History):
+             raise TypeError(
+                 f"Expected History object for '{self.input_key}', got {type(history)}"
+             )
+
+         # Apply chat template
+         tokenizer_kwargs = {}
+         if self.chat_template_name is not None:
+             tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+         if self.chat_template is not None:
+             tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         text_full = history.apply_chat_template(
+             tokenizer=self.tokenizer, **tokenizer_kwargs
+         )
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("return_dict", True)
+         response_struct = history.apply_chat_template(
+             tokenizer=self.tokenizer, **tokenizer_kwargs
+         )
+
+         result = self._logprobs_from_tokens(
+             response_struct=response_struct, sampling_params=sampling_params, out=out
+         )
+         text_result = Text._from_tensordict(result.empty())
+         result.set(self.text_key, text_result)
+         result[self.text_key, "full"] = text_full
+         result.set(self.history_key, ChatHistory(full=history))
+         return result
+
+     def _from_vllm_generate_text(
+         self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
+     ) -> TensorDictBase:
+         """Generate text from text input."""
+         # Type assertions
+         assert isinstance(
+             td, TensorDictBase
+         ), f"td must be TensorDictBase, got {type(td)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         # Validate input
+         if self.input_key not in td:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for text input mode, "
+                 f"but found keys: {list(td.keys())}"
+             )
+
+         text = td.get(self.input_key)
+         if text is None:
+             raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+
+         return self._generate_from_text(text, sampling_params, out)
+
+     def _from_vllm_logprobs_text(
+         self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
+     ) -> TensorDictBase:
+         """Compute log-probs from text input."""
+         # Type assertions
+         assert isinstance(
+             td, TensorDictBase
+         ), f"td must be TensorDictBase, got {type(td)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         # Validate input
+         if self.input_key not in td:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for text input mode, "
+                 f"but found keys: {list(td.keys())}"
+             )
+
+         text = td.get(self.input_key)
+         if text is None:
+             raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+
+         return self._logprobs_from_text(text, sampling_params, out)
+
+     def _from_vllm_generate_tokens(
+         self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
+     ) -> TensorDictBase:
+         """Generate text from tokens input."""
+         # Type assertions
+         assert isinstance(
+             td, TensorDictBase
+         ), f"td must be TensorDictBase, got {type(td)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         # Validate input
+         if self.input_key not in td:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for tokens input mode, "
+                 f"but found keys: {list(td.keys())}"
+             )
+
+         tokens_prompt_padded = None
+         tokens_prompt_unpadded = None
+         if self.pad_output:
+             tokens_prompt_padded = td.get(self.input_key)
+         else:
+             tokens_prompt_unpadded = list(td.get(self.input_key, as_list=True))
+             # make sure we remove the padding tokens
+             tokens_prompt_unpadded = [
+                 tokens[tokens != self.padding_value]
+                 for tokens in tokens_prompt_unpadded
+             ]
+
+         return self._generate_from_tokens(
+             tokens_prompt_unpadded=tokens_prompt_unpadded,
+             tokens_prompt_padded=tokens_prompt_padded,
+             sampling_params=sampling_params,
+             out=out,
+         )
+
+     def _from_vllm_logprobs_tokens(
+         self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
+     ) -> TensorDictBase:
+         """Compute log-probs from tokens input."""
+         # Type assertions
+         assert isinstance(
+             td, TensorDictBase
+         ), f"td must be TensorDictBase, got {type(td)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         # Validate input
+         if self.input_key not in td:
+             raise ValueError(
+                 f"Expected '{self.input_key}' key for tokens input mode, "
+                 f"but found keys: {list(td.keys())}"
+             )
+
+         tokens_full_padded = None
+         tokens_full_unpadded = None
+         if self.pad_output:
+             tokens_full_padded = td.get(self.input_key)
+         else:
+             tokens_full_unpadded = list(td.get(self.input_key, as_list=True))
+             # make sure we remove the padding tokens
+             tokens_full_unpadded = [
+                 tokens[tokens != self.padding_value] for tokens in tokens_full_unpadded
+             ]
+
+         return self._logprobs_from_tokens(
+             response_struct=None,
+             tokens_full_unpadded=tokens_full_unpadded,
+             tokens_full_padded=tokens_full_padded,
+             sampling_params=sampling_params,
+             out=out,
+         )
+
+     def _cat_text(
+         self, text: str | list[str], response_text: str | list[str] | None
+     ) -> str | list[str]:
+         """Concatenate text and response text."""
+         assert isinstance(
+             text, (str, list)
+         ), f"text must be str or list, got {type(text)}"
+
+         # Handle None response_text (when tokenizer is not available)
+         if response_text is None:
+             raise RuntimeError(
+                 "response_text is None, likely due to missing tokenizer. "
+                 "Cannot decode vLLM response without a tokenizer. "
+                 "Please provide a tokenizer explicitly or ensure the model has one available."
+             )
+
+         assert isinstance(
+             response_text, (str, list)
+         ), f"response_text must be str or list, got {type(response_text)}"
+
+         if isinstance(text, list):
+             return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
+         else:
+             return text + response_text
+
+     def _generate_from_text(
+         self,
+         text: str | list[str] | NonTensorStack,
+         sampling_params: Any,
+         out: TensorDictBase,
+     ) -> TensorDictBase:
+         """Generate text from text input."""
+         # Convert text to list format
+         if isinstance(text, str):
+             text = [text]
+         elif not isinstance(text, list):
+             text = text.tolist()
+
+         assert isinstance(
+             text, (str, list)
+         ), f"text must be str or list, got {type(text)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         generate_kwargs = {"sampling_params": sampling_params}
+         args = ()
+
+         # Call generate based on model type
+         request_output = self._call_generate(text, *args, **generate_kwargs)
+
+         request_output_tc = _RequestOutput_tc.from_request_output(request_output)
+
+         # Extract response tokens and text
+         outputs = (
+             request_output_tc.outputs.view(-1)
+             if self.num_samples is not None
+             else request_output_tc.outputs
+         )
+         if self.pad_output:
+             response_tokens_padded = outputs.view(-1).get(
+                 "token_ids",
+                 as_padded_tensor=self.pad_output,
+                 padding_value=self.padding_value,
+                 padding_side="right",
+             )
+         response_tokens_list = outputs.view(-1).get(
+             "token_ids",
+             as_list=True,
+         )
+         self._check_not_padded(response_tokens_list)
+         if self.tokenizer is not None:
+             response_text = self.tokenizer.batch_decode(
+                 response_tokens_list, skip_special_tokens=False
+             )
+         else:
+             response_text = None
+
+         # Build output TensorClass objects
+         masks_obj = Masks._from_tensordict(out.empty())
+         masks_obj.all_attention_mask = None
+         masks_obj.all_assistant_mask = None
+         masks_obj.padded = MetaData(self.pad_output)
+         out.set(self.masks_key, masks_obj)
+
+         if self.num_samples is not None:
+             text = [txt for txt in text for _ in range(self.num_samples)]
+         text_obj = Text._from_tensordict(out.empty())
+         with text_obj.view(-1) as text_obj_flat:
+             text_obj_flat.prompt = text
+             text_obj_flat.response = response_text
+             text_obj_flat.full = self._cat_text(text, response_text)
+         out.set(self.text_key, text_obj)
+
+         tokens_obj = Tokens._from_tensordict(out.empty())
+         with tokens_obj.view(-1) as tokens_obj_flat:
+             tokens_obj_flat.prompt = None  # We don't have prompt tokens in this path
+             if self.pad_output:
+                 tokens_obj_flat.response = response_tokens_padded
+                 self._check_padded(response_tokens_padded)
+             else:
+                 tokens_obj_flat.response = response_tokens_list
+                 self._check_not_padded(response_tokens_list)
+             tokens_obj_flat.full = (
+                 None  # we don't have prompt tokens in this path, so no full tokens either
+             )
+         tokens_obj.padded = MetaData(self.pad_output)
+         out.set(self.tokens_key, tokens_obj)
+
+         if self.return_log_probs:
+             log_probs_obj = LogProbs._from_tensordict(out.empty())
+             with log_probs_obj.view(-1) as log_probs_obj_flat:
+                 if self.pad_output:
+                     log_probs_padded = outputs.get(
+                         "logprobs",
+                         as_padded_tensor=self.pad_output,
+                         padding_value=self.padding_value,
+                         padding_side="right",
+                     )
+                     self._check_padded(log_probs_padded)
+                     log_probs_obj_flat.response = log_probs_padded
+                     log_probs_obj_flat.full = log_probs_padded
+                 else:
+                     log_probs_list = outputs.get(
+                         "logprobs",
+                         as_list=True,
+                     )
+                     self._check_not_padded(log_probs_list)
+                     log_probs_obj_flat.response = log_probs_list
+                     log_probs_obj_flat.full = log_probs_list
+                 log_probs_obj_flat.prompt = None
+             log_probs_obj.padded = MetaData(self.pad_output)
+             out.set(self.log_probs_key, log_probs_obj)
+
+         return out
+
+     def _logprobs_from_text(
+         self,
+         text: str | list[str] | NonTensorStack,
+         sampling_params: Any,
+         out: TensorDictBase,
+     ) -> TensorDictBase:
+         """Compute log-probs from text input."""
+         # Convert text to list format
+         if isinstance(text, str):
+             text = [text]
+         elif not isinstance(text, list):
+             text = text.tolist()
+
+         assert isinstance(
+             text, (str, list)
+         ), f"text must be str or list, got {type(text)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         if self.tokenizer is None:
+             raise ValueError(
+                 "Tokenizer is required for log-probs computation with text input"
+             )
+
+         # Tokenize the text
+         tokenized_output = self.tokenizer(text, **self.tokenizer_kwargs)
+         if self.pad_output:
+             tokens_full_padded = tokenized_output["input_ids"]
+             attention_mask_full_padded = tokenized_output["attention_mask"]
+             tokens_full_list = self._to_list(
+                 tokens_full_padded, attention_mask_full_padded
+             )
+         else:
+             tokens_full_unpadded = tokenized_output["input_ids"]
+             tokens_full_list = self._to_list(tokens_full_unpadded, None)
+             attention_mask_full_unpadded = tokenized_output["attention_mask"]
+             attention_mask_full_unpadded = [
+                 am.bool()
+                 if isinstance(am, torch.Tensor)
+                 else torch.tensor(am, dtype=torch.bool)
+                 for am in attention_mask_full_unpadded
+             ]
+
+         # Convert to list format for vLLM
+         generate_kwargs = {
+             "sampling_params": sampling_params,
+             "prompt_token_ids": tokens_full_list,
+         }
+
+         # Generate with vLLM to get prompt_logprobs
+         request_output = self._call_generate(**generate_kwargs)
+
+         request_output_tc = _RequestOutput_tc.from_request_output(request_output)
+
+         # Extract log-probs from prompt_logprobs
+         if self.pad_output:
+             # For the padded case, use all prompt_logprobs
+             log_probs_full_padded = request_output_tc.get(
+                 "prompt_logprobs",
+                 as_padded_tensor=True,
+                 padding_value=0,
+                 padding_side="left",
+             )
+
+             # Mask out padding
+             attention_mask_full_padded = tokens_full_padded != self.padding_value
+             log_probs_full_padded = torch.where(
+                 attention_mask_full_padded, log_probs_full_padded, 0.0
+             )
+         else:
+             # For the unpadded case, extract from each sequence
+             log_probs_full_unpadded = request_output_tc.get(
+                 "prompt_logprobs", as_list=True
+             )
+             self._check_not_padded(log_probs_full_unpadded)
+
+         masks_obj = Masks._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         if self.pad_output:
+             self._check_padded(attention_mask_full_padded)
+             masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+         else:
+             self._check_not_padded(attention_mask_full_unpadded)
+             masks_obj.all_attention_mask = attention_mask_full_unpadded
+         masks_obj.padded = MetaData(self.pad_output)
+         out.set(self.masks_key, masks_obj)
+
+         # Build output TensorClass objects
+         text_obj = Text._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         text_obj.prompt = None
+         text_obj.response = None
+         text_obj.full = text
+         out.set(self.text_key, text_obj)
+
+         tokens_obj = Tokens._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         if self.pad_output:
+             self._check_padded(tokens_full_padded)
+             tokens_obj.full = tokens_full_padded
+         else:
+             tokens_obj.full = tokens_full_unpadded
+         tokens_obj.response = None
+         tokens_obj.padded = MetaData(self.pad_output)
+         out.set(self.tokens_key, tokens_obj)
+
+         if self.return_log_probs:
+             log_probs_obj = LogProbs._from_tensordict(
+                 TensorDict(batch_size=out.batch_size).to_lazystack(0)
+             )
+             if self.pad_output:
+                 self._check_padded(log_probs_full_padded)
+                 log_probs_obj.full = log_probs_full_padded
+             else:
+                 self._check_not_padded(log_probs_full_unpadded)
+                 log_probs_obj.full = log_probs_full_unpadded
+             log_probs_obj.response = None
+             log_probs_obj.padded = MetaData(self.pad_output)
+             out.set(self.log_probs_key, log_probs_obj)
+
+         return out
+
+     def _cat_tensors(
+         self,
+         tokens: list[torch.Tensor] | torch.Tensor,
+         response_tokens: list[torch.Tensor] | torch.Tensor,
+     ) -> list[torch.Tensor] | torch.Tensor:
+         """Concatenate tokens and response tokens."""
+         if isinstance(tokens, list) or isinstance(response_tokens, list):
+             return [
+                 self._cat_tensors(t, t_)
+                 for t, t_ in _zip_strict(tokens, response_tokens)
+             ]
+         else:
+             return torch.cat([tokens, response_tokens], dim=-1)
+
+     def _generate_from_tokens(
+         self,
+         tokens_prompt_unpadded: list[torch.Tensor] | None,
+         tokens_prompt_padded: torch.Tensor | None,
+         sampling_params: Any,
+         out: TensorDictBase,
+     ) -> TensorDictBase:
+         """Generate text from tokens input."""
+         assert isinstance(
+             tokens_prompt_padded, (torch.Tensor, type(None))
+         ), f"tokens_prompt_padded must be torch.Tensor or None, got {type(tokens_prompt_padded)}"
+         assert isinstance(
+             tokens_prompt_unpadded, (list, type(None))
+         ), f"tokens_prompt_unpadded must be list or None, got {type(tokens_prompt_unpadded)}"
+         assert isinstance(
+             sampling_params, SamplingParams
+         ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+         assert isinstance(
+             out, TensorDictBase
+         ), f"out must be TensorDictBase, got {type(out)}"
+
+         generate_kwargs = {"sampling_params": sampling_params}
+         args = ()
+         empirical_attention_mask = None
+
+         if tokens_prompt_unpadded is None:
+             # TODO: To be on the safe side, we may do this even in the unpadded case,
+             # since we're not sure the user passed an unpadded tensor in the first place.
+             empirical_attention_mask = tokens_prompt_padded != self.padding_value
+             tokens_prompt_list = self._to_list(
+                 tokens_prompt_padded, empirical_attention_mask
+             )
+         else:
+             tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
+         generate_kwargs.update({"prompt_token_ids": tokens_prompt_list})
+
+         # Call generate based on model type
+         request_output = self._call_generate(*args, **generate_kwargs)
+
+         request_output_tc = _RequestOutput_tc.from_request_output(request_output)
+
+         # Extract response tokens and text
+         outputs = (
+             request_output_tc.outputs.view(-1)
+             if self.num_samples is not None
+             else request_output_tc.outputs
+         )
+         if self.pad_output:
+             tokens_response_padded = outputs.get(
+                 "token_ids",
+                 as_padded_tensor=self.pad_output,
+                 padding_value=self.padding_value,
+                 padding_side="right",
+             )
+             self._check_padded(tokens_response_padded)
+         tokens_response_unpadded = outputs.get(
+             "token_ids",
+             as_list=True,
+         )
+         self._check_not_padded(tokens_response_unpadded)
+
+         tokens_obj = Tokens._from_tensordict(out.empty())
+         if self.pad_output:
+             self._check_padded(tokens_response_padded)
+             self._check_padded(tokens_prompt_padded)
+         else:
+             self._check_not_padded(tokens_response_unpadded)
+             self._check_not_padded(tokens_prompt_unpadded)
+
+         if self.num_samples is not None:
+             # replicate the prompt tokens across samples
+             for i in range(self.num_samples):
+                 tokens_obj[:, i].prompt = (
+                     tokens_prompt_unpadded
+                     if not self.pad_output
+                     else tokens_prompt_padded
+                 )
+         else:
+             tokens_obj.prompt = (
+                 tokens_prompt_unpadded if not self.pad_output else tokens_prompt_padded
+             )
+         with tokens_obj.view(-1) as tokens_obj_flat:
+             if self.pad_output:
+                 tokens_obj_flat.response = tokens_response_padded
+                 tokens_full_padded = self._cat_tensors(
+                     tokens_obj_flat.prompt, tokens_response_padded
+                 )
+                 tokens_obj_flat.full = tokens_full_padded
+             else:
+                 tokens_obj_flat.response = tokens_response_unpadded
+                 tokens_full_unpadded = self._cat_tensors(
+                     tokens_obj_flat.get("prompt", as_list=True),
+                     tokens_response_unpadded,
+                 )
+                 tokens_obj_flat.full = tokens_full_unpadded
+         tokens_obj.padded = MetaData(self.pad_output)
+         out.set(self.tokens_key, tokens_obj)
+
+         masks_obj = Masks._from_tensordict(out.empty())
+         # self.return_tokens must be True
+         if self.pad_output:
+             # Get "real" attention masks
+             full_attention_mask_padded = tokens_obj.get("full") != self.padding_value
+             masks_obj.all_attention_mask = full_attention_mask_padded.bool()
+         else:
+             # Get "real" attention masks
+             # We can use select to avoid batch-size problems
+             _td = torch.ones_like(
+                 out.select(("tokens", "full"))
+                 .copy()
+                 .rename_key_(("tokens", "full"), "all_attention_mask")
+             ).bool()
+             del _td["tokens"]
+             masks_obj.update(_td)
+         masks_obj.all_assistant_mask = None
+         masks_obj.padded = MetaData(self.pad_output)
+         out.set(self.masks_key, masks_obj)
+
+         if self.return_log_probs:
+             if self.pad_output:
+                 log_probs_padded = outputs.get(
+                     "logprobs",
+                     as_padded_tensor=self.pad_output,
+                     padding_value=self.padding_value,
+                     padding_side="right",
+                 )
+             else:
+                 log_probs_list = outputs.get(
+                     "logprobs",
+                     as_list=True,
+                 )
+                 self._check_not_padded(log_probs_list)
+             if self.num_samples is None:
+                 # TODO: this is not correct, we should use the prompt_logprobs,
+                 # but they're not returned by vLLM
+                 if self.pad_output:
+                     prompt_logprobs_padded = request_output_tc.get(
+                         "prompt_logprobs",
+                         as_padded_tensor=self.pad_output,
+                         padding_value=self.padding_value,
+                         padding_side="right",
+                     )
+                     if (
+                         prompt_logprobs_padded.shape[-1]
+                         != tokens_prompt_padded.shape[-1]
+                     ):
+                         tshape = tokens_prompt_padded.shape
+                         oshape = prompt_logprobs_padded.shape
+                         # the input may already have been padded - pad again then
+                         prompt_logprobs_padded = torch.cat(
+                             [
+                                 prompt_logprobs_padded.new_zeros(
+                                     tshape[:-1] + (tshape[-1] - oshape[-1],)
+                                 ),
+                                 prompt_logprobs_padded,
+                             ],
+                             -1,
+                         )
+                 else:
+                     prompt_logprobs_list = request_output_tc.get(
+                         "prompt_logprobs",
+                         as_list=True,
+                     )
+                     self._check_not_padded(prompt_logprobs_list)
+             log_probs_obj = LogProbs._from_tensordict(out.empty())
+             if self.pad_output:
+                 self._check_padded(log_probs_padded)
+                 if self.num_samples is None:
+                     self._check_padded(prompt_logprobs_padded)
+                     log_probs_obj.prompt = prompt_logprobs_padded
+             else:
+                 self._check_not_padded(log_probs_list)
+                 if self.num_samples is None:
+                     self._check_not_padded(prompt_logprobs_list)
+                     log_probs_obj.prompt = prompt_logprobs_list
+             with log_probs_obj.view(-1) as log_probs_obj_flat:
+                 log_probs_obj_flat.response = (
+                     log_probs_padded if self.pad_output else log_probs_list
+                 )
+                 if self.num_samples is None:
+                     if self.pad_output:
+                         log_probs_obj_flat.full = self._cat_tensors(
+                             log_probs_obj_flat.prompt, log_probs_padded
+                         )
+                     else:
+                         log_probs_obj_flat.full = self._cat_tensors(
+                             log_probs_obj_flat.get("prompt", as_list=True),
+                             log_probs_list,
+                         )
+                 else:
+                     log_probs_obj_flat.full = None
+             log_probs_obj.padded = MetaData(self.pad_output)
+             out.set(self.log_probs_key, log_probs_obj)
+         return out
+
+     def _logprobs_from_tokens(
+         self,
+         *,
+         response_struct: TensorDictBase | None = None,
+         tokens_full_unpadded: list[torch.Tensor] | None = None,
+         tokens_full_padded: torch.Tensor | None = None,
+         sampling_params: Any | None = None,
+         out: TensorDictBase | None = None,
+     ) -> TensorDictBase:
+         """Compute log-probs from tokens input."""
+         assert isinstance(
+             response_struct, (TensorDictBase, type(None))
+         ), f"response_struct must be TensorDictBase or None, got {type(response_struct)}"
+         assert isinstance(
+             tokens_full_unpadded, (list, type(None))
+         ), f"tokens_full_unpadded must be list or None, got {type(tokens_full_unpadded)}"
+         assert isinstance(
+             tokens_full_padded, (torch.Tensor, type(None))
+         ), f"tokens_full_padded must be torch.Tensor or None, got {type(tokens_full_padded)}"
+         assert isinstance(
+             sampling_params, (SamplingParams, type(None))
+         ), f"sampling_params must be SamplingParams or None, got {type(sampling_params)}"
+         assert isinstance(
+             out, (TensorDictBase, type(None))
+         ), f"out must be TensorDictBase or None, got {type(out)}"
+
+         # Convert to list format for vLLM
+         if response_struct is not None:
+             tokens_full_padded = response_struct.get(
+                 "input_ids",
+                 as_padded_tensor=True,
+                 padding_value=self.padding_value,
+                 padding_side="left",
+             )
+             attention_mask_full_padded = response_struct.get(
+                 "attention_mask",
+                 as_padded_tensor=True,
+                 padding_value=False,
+                 padding_side="left",
+             ).bool()
+             attention_mask_full_unpadded = _unpad_tensors(
+                 attention_mask_full_padded, attention_mask_full_padded, as_nested=False
+             )
+         elif tokens_full_unpadded is not None:
+             tokens_full_padded = pad_sequence(
+                 tokens_full_unpadded,
+                 padding_value=self.padding_value,
+                 batch_first=True,
+                 padding_side="left",
+             )
+             attention_mask_full_unpadded = [
+                 t != self.padding_value for t in tokens_full_unpadded
+             ]
+             attention_mask_full_padded = pad_sequence(
+                 attention_mask_full_unpadded,
+                 padding_value=False,
+                 batch_first=True,
+                 padding_side="left",
+             )
+         elif tokens_full_padded is not None:
+             attention_mask_full_padded = tokens_full_padded != self.padding_value
+         else:
+             raise ValueError("Either response_struct or tokens must be provided")
+
+         assert isinstance(tokens_full_padded, torch.Tensor)
+         assert isinstance(attention_mask_full_padded, torch.Tensor)
+         if tokens_full_unpadded is None:
+             tokens_full_list = self._to_list(
+                 tokens_full_padded, attention_mask_full_padded
+             )
+         else:
+             tokens_full_list = self._to_list(tokens_full_unpadded, None)
+
+         generate_kwargs = {
+             "sampling_params": sampling_params,
+             "prompt_token_ids": tokens_full_list,
+         }
+
+         # Generate with vLLM to get prompt_logprobs
+         tokens_out_struct = self._call_generate(**generate_kwargs)
+
+         request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_struct)
+
+         # For the unpadded case, extract from each sequence
+         log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)
+
+         # Extract log-probs from prompt_logprobs
+         if self.pad_output:
+             # For the padded case, use all prompt_logprobs
+             if attention_mask_full_padded is not None:
+                 attention_mask_full_padded = tokens_full_padded != self.padding_value
+             log_probs_full_padded = torch.zeros_like(
+                 tokens_full_padded, dtype=torch.get_default_dtype()
+             )
+             log_probs_full_padded[attention_mask_full_padded] = torch.cat(
+                 log_probs_full_unpadded, -1
+             )
+         else:
+             self._check_not_padded(log_probs_full_unpadded)
+
+         assistant_mask_full_padded = None
+         if response_struct is not None:
+             assistant_mask_full_padded = response_struct.get(
+                 "assistant_masks",
+                 as_padded_tensor=True,
+                 padding_side="left",
+                 padding_value=0,
+             )
+             if assistant_mask_full_padded is not None:
+                 assistant_mask_full_padded = assistant_mask_full_padded.bool()
+                 if not self.pad_output:
+                     assistant_mask_full_unpadded = _unpad_tensors(
+                         assistant_mask_full_padded,
+                         attention_mask_full_padded,
+                         as_nested=False,
+                     )
+                 else:
+                     assistant_mask_full_unpadded = None
+             else:
+                 assistant_mask_full_unpadded = None
+         else:
+             assistant_mask_full_unpadded = None
+
+         masks_obj = Masks._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         if self.pad_output:
+             self._check_padded(attention_mask_full_padded)
+             masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+             if assistant_mask_full_padded is not None:
+                 masks_obj.all_assistant_mask = assistant_mask_full_padded
+         else:
+             self._check_not_padded(attention_mask_full_unpadded)
+             masks_obj.all_attention_mask = attention_mask_full_unpadded
+             if assistant_mask_full_unpadded is not None:
+                 masks_obj.all_assistant_mask = assistant_mask_full_unpadded
+         masks_obj.padded = MetaData(self.pad_output)
+         out.set(self.masks_key, masks_obj)
+
+         tokens_obj = Tokens._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         if self.pad_output:
+             self._check_padded(tokens_full_padded)
+             tokens_obj.full = tokens_full_padded
+         else:
+             tokens_obj.full = tokens_full_unpadded
+         tokens_obj.response = None
+         tokens_obj.padded = MetaData(self.pad_output)
+         out.set(self.tokens_key, tokens_obj)
+
+         log_probs_obj = LogProbs._from_tensordict(
+             TensorDict(batch_size=out.batch_size).to_lazystack(0)
+         )
+         if self.pad_output:
+             self._check_padded(log_probs_full_padded)
+             log_probs_obj.full = log_probs_full_padded
+         else:
+             self._check_not_padded(log_probs_full_unpadded)
+             log_probs_obj.full = log_probs_full_unpadded
+         log_probs_obj.response = None
+         log_probs_obj.padded = MetaData(self.pad_output)
+         out.set(self.log_probs_key, log_probs_obj)
+
+         return out
+
+     def _to_list(
+         self,
+         tokens_padded: torch.Tensor | list[torch.Tensor],
+         attention_mask_padded: torch.Tensor | None,
+     ) -> list[list[int]]:
+         """Converts a tensor of integers into a masked list (of lists) of integers."""
+         if isinstance(tokens_padded, torch.Tensor):
+             parent = []
+             queue = collections.deque()
+             if attention_mask_padded is None:
+                 attention_mask_padded = torch.ones_like(tokens_padded)
+             queue.append((tokens_padded, attention_mask_padded.bool(), parent))
+             while queue:
+                 token_tensor, attention_mask_bool, _parent = queue.popleft()
+                 if token_tensor.ndim == 1:
+                     _parent.extend(token_tensor[attention_mask_bool].tolist())
+                 else:
+                     _parent.extend([[] for _ in range(token_tensor.shape[0])])
+                     queue.extend(
+                         [
+                             (t, m, local_parent)
+                             for t, m, local_parent in zip(
+                                 token_tensor, attention_mask_bool, _parent
+                             )
+                         ]
+                     )
+             tokens_list = parent
+         elif isinstance(tokens_padded, list):
+             parent = []
+             queue = collections.deque()
+             queue.append((tokens_padded, parent))
+             while queue:
+                 tokens_list, _parent = queue.popleft()
+                 if isinstance(tokens_list, list) and isinstance(
+                     tokens_list[0], (list, torch.Tensor)
+                 ):
+                     _parent.extend([[] for _ in tokens_list])
+                     queue.extend(
+                         [
+                             (t, local_parent)
+                             for t, local_parent in zip(tokens_list, _parent)
+                         ]
+                     )
+                     continue
+                 elif isinstance(tokens_list, torch.Tensor):
+                     tokens_list = tokens_list.tolist()
+                 _parent.extend(tokens_list)
+             tokens_list = parent
+
+         return tokens_list
+
+     @_classproperty
+     def CompletionOutput_tc(cls):
+         _vllm = _require_vllm()
+
+         if hasattr(cls, "_CompletionOutput_tc"):
+             return cls._CompletionOutput_tc
+         CompletionOutput_tc = from_dataclass(_vllm.outputs.CompletionOutput)  # type: ignore
+         cls._CompletionOutput_tc = CompletionOutput_tc
+         return CompletionOutput_tc
+
+     def get_dist(
+         self,
+         tensordict: TensorDictBase,
+         tensordict_out: TensorDictBase | None = None,
+         logits_key: NestedKey = "logits",
+         mask_key: NestedKey | None = None,
+         as_padded_tensor: bool | None = None,
+         as_nested_tensor: bool | None = None,
+         padding_value: float | None = None,
+         padding_side: str = "right",
+         layout: torch.layout | None = None,
+         **kwargs,
+     ) -> D.Distribution:
+         """Get distribution from logits/log-probs with optional masking.
+
+         vLLM does not return logits, so this method is not supported.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_dist is not supported"
+         )
+
+     def get_dist_with_prompt_mask(
+         self,
+         tensordict: TensorDictBase,
+         tokens_key: NestedKey = ("tokens", "full"),
+         logits_key: NestedKey = "logits",
+         assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+         attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+         **kwargs,
+     ) -> D.Distribution:
+         """Get distribution masked to only include response tokens (exclude prompt).
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_dist_with_prompt_mask is not supported"
+         )
+
+     def _get_dist_with_assistant_mask(
+         self,
+         tensordict: TensorDictBase,
+         assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+         logits_key: NestedKey = "logits",
+         **kwargs,
+     ) -> D.Distribution:
+         """Get distribution masked to only include assistant tokens.
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_dist_with_assistant_mask is not supported"
+         )
+
+     def _get_dist_with_attention_mask(
+         self,
+         tensordict: TensorDictBase,
+         attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+         logits_key: NestedKey = "logits",
+         **kwargs,
+     ) -> D.Distribution:
+         """Get distribution masked using attention mask.
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_dist_with_attention_mask is not supported"
+         )
+
+     def _get_dist_with_custom_mask(
+         self,
+         tensordict: TensorDictBase,
+         mask: torch.Tensor,
+         logits_key: NestedKey = "logits",
+         **kwargs,
+     ) -> D.Distribution:
+         """Get distribution with custom mask.
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_dist_with_custom_mask is not supported"
+         )
+
+     # Convenience methods for common LLM training scenarios
+     def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+         """Get distribution suitable for SFT loss (response tokens only).
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_sft_dist is not supported"
+         )
+
+     def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+         """Get distribution suitable for RLHF loss (assistant tokens only).
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_rlhf_dist is not supported"
+         )
+
+     def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+         """Get distribution suitable for generic losses (all tokens).
+
+         vLLM does not return logits, so this method is not supported.
+
+         This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+         """
+         raise NotImplementedError(
+             "vLLM does not return logits, so get_generic_dist is not supported"
+         )
+
+
+ class _RequestOutput_tc(TensorClass["nocast"]):
+     """TensorClass wrapper for vLLM RequestOutput."""
+
+     request_id: str
+     prompt: str
+     prompt_token_ids: torch.Tensor
+     prompt_logprobs: torch.Tensor
+     outputs: Any
+     finished: str
+     metrics: str
+     lora_request: str
+     encoder_prompt: str
+     encoder_prompt_token_ids: str
+     num_cached_tokens: torch.Tensor
+
+     def __post_init__(self):
+         CompletionOutput_tc = vLLMWrapper.CompletionOutput_tc
+
+         def postproc(output):
+             def get_logprob(output):
+                 t = []
+                 token_ids = output.token_ids
+                 if isinstance(token_ids, torch.Tensor):
+                     token_ids = token_ids.tolist()
+                 for v, tid in zip(output.logprobs, token_ids):
+                     t.append(
+                         v[tid]["logprob"] if v[tid].get("logprob") is not None else 0.0
+                     )
+                 return torch.tensor(t)
+
+             if output.logprobs:
+                 output.logprobs = get_logprob(output)
+             output.token_ids = torch.as_tensor(output.token_ids)
+             return output
+
+         if isinstance(self.outputs, list):
+             outputs = self.outputs
+             outputs = [
+                 postproc(from_dataclass(output, dest_cls=CompletionOutput_tc))
+                 for output in outputs
+             ]
+             if len(outputs) == 1:
+                 self.outputs = outputs[0]
+             else:
+                 # Check if we can stack the outputs (they should have the same shape)
+                 try:
+                     self.outputs = lazy_stack(outputs)
+                 except RuntimeError:
+                     # If stacking fails (different sizes), keep as list
+                     self.outputs = outputs
+
+     @classmethod
+     def from_request_output(
+         cls, requests: RequestOutput | list[RequestOutput]
+     ) -> _RequestOutput_tc | list[_RequestOutput_tc]:
+         """Create _RequestOutput_tc from vLLM RequestOutput."""
+         # Type assertions
+         assert isinstance(
+             requests, (RequestOutput, list)
+         ), f"requests must be RequestOutput or list, got {type(requests)}"
+
+         # Build one entry per request; the construction is shared between the
+         # stacked result and the list fallback.
+         entries = [
+             cls(
+                 request_id=request.request_id,
+                 prompt=request.prompt,
+                 prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
+                 prompt_logprobs=torch.tensor(
+                     [
+                         v[int(tid)].logprob if v is not None else 0.0
+                         for v, tid in _zip_strict(
+                             request.prompt_logprobs, request.prompt_token_ids
+                         )
+                     ]
+                 )
+                 if request.prompt_logprobs is not None
+                 else torch.tensor([]),
+                 outputs=request.outputs,
+                 finished=request.finished,
+                 metrics=request.metrics,
+                 lora_request=request.lora_request,
+                 encoder_prompt=request.encoder_prompt,
+                 encoder_prompt_token_ids=request.encoder_prompt_token_ids,
+                 num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
+             )
+             for request in requests
+         ]
+         # Check if we can stack the outputs
+         try:
+             return lazy_stack(entries)
+         except RuntimeError:
+             # If stacking fails, return a list of individual _RequestOutput_tc objects
+             return entries
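+
+ # Illustrative use of the class above (a sketch; assumes a vLLM engine `llm`,
+ # prompts, and sampling params are available, and keeps the names hypothetical):
+ #
+ #     request_output = llm.generate(prompts, sampling_params)
+ #     tc = _RequestOutput_tc.from_request_output(request_output)
+ #     tc.outputs.get("token_ids", as_list=True)  # per-sample response token ids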