torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/llm/policies/common.py
@@ -0,0 +1,1809 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import threading
+ import warnings
+ import weakref
+ from concurrent.futures import CancelledError, Future, wait
+
+ from contextlib import nullcontext
+ from functools import wraps
+ from typing import Any, Literal, overload, TYPE_CHECKING
+
+ import torch
+ from tensordict import lazy_stack, NestedKey, TensorDictBase
+ from tensordict.nn import TensorDictModuleBase
+ from tensordict.tensorclass import TensorClass
+ from tensordict.utils import _zip_strict
+ from torch import distributions as D
+ from torch.distributions import Categorical
+ from torch.nn.utils.rnn import pad_sequence
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.data.llm import History
+ from torchrl.data.tensor_specs import Unbounded
+ from torchrl.modules.distributions.discrete import LLMMaskedCategorical
+
+ if TYPE_CHECKING:
+     from transformers import AutoTokenizer
+
+ # TODOs:
+ # - [ ] Remove the useless view(-1) calls when num_samples is not > 1
+ # - [ ] Remove as_list=True and use a context manager to handle that
+ # - [ ] Make sure tensordict can handle nested lazy tds that have a get(key, as_list=True) - I think it breaks atm
+ # - [ ] Handle packing
+
+
+ class Tokens(TensorClass["nocast"]):
+     """A Tokens container.
+
+     Args:
+         prompt (torch.Tensor | None): The prompt tokens.
+         response (torch.Tensor | None): The response tokens.
+         assistant (torch.Tensor | None): The assistant tokens.
+         full (torch.Tensor | None): The tokens across prompt and response.
+         padded (bool | None): Whether the tokens are padded.
+
+     Shapes:
+         - prompt: (batch_size, prompt_length). If padded, padded on the left.
+         - response: (batch_size, response_length). If padded, padded on the right.
+         - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
+         - padded: bool.
+
+     """
+
+     prompt: torch.Tensor | None = None
+     response: torch.Tensor | None = None
+     full: torch.Tensor | None = None
+     padded: bool | None = None
+
+     @classmethod
+     def default_spec(
+         cls,
+         shape=(-1,),
+         keys: list[Literal["prompt", "response", "full"]] | None = None,
+     ):
+         """A default spec to use in transforms / envs that return Tokens objects."""
+         from torchrl.data import Composite, NonTensor
+
+         if keys is None:
+             keys = ["prompt", "response", "full"]
+
+         defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+         defaults["padded"] = NonTensor(shape=shape, example_data=False)
+
+         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+     def to_text(
+         self,
+         tokenizer: AutoTokenizer,
+         skip_special_tokens: bool = False,
+     ) -> Text:
+         """Convert tokens to text using the tokenizer.
+
+         Args:
+             tokenizer: The tokenizer to use for decoding.
+             skip_special_tokens: Whether to skip special tokens in the output.
+
+         Returns:
+             A Text object with decoded text.
+
+         Raises:
+             ValueError: If padded tokens are provided (not yet supported).
+         """
+         # Check if padded - handle both bool and LinkedList cases
+         padded = self.padded
+         if isinstance(padded, bool):
+             if padded:
+                 raise ValueError(
+                     "Conversion from padded tokens to text is not yet supported. "
+                     "Please use unpadded tokens (nested tensors)."
+                 )
+         else:
+             # LinkedList case (when stacked) - check if any are True
+             padded_list = self.view(-1).padded
+             if any(padded_list):
+                 raise ValueError(
+                     "Conversion from padded tokens to text is not yet supported. "
+                     "Please use unpadded tokens (nested tensors)."
+                 )
+
+         # Create output structure
+         text_out = Text._from_tensordict(self._tensordict.empty())
+
+         # Helper to prepare tokens for batch_decode
+         def _prepare_tokens_for_decode(tokens_list):
+             """Ensure tokens are in the right format for batch_decode."""
+             if isinstance(tokens_list, list):
+                 # Squeeze out extra batch dimensions if present
+                 return [t.squeeze(0) if t.dim() > 1 else t for t in tokens_list]
+             else:
+                 # Single tensor case
+                 return tokens_list
+
+         # Decode prompt if available
+         if "prompt" in self._tensordict.keys():
+             prompt_tokens_list = self.get("prompt", as_list=True)
+             prompt_tokens_list = _prepare_tokens_for_decode(prompt_tokens_list)
+             prompt_texts = tokenizer.batch_decode(
+                 prompt_tokens_list, skip_special_tokens=skip_special_tokens
+             )
+             text_out.set("prompt", prompt_texts)
+
+         # Decode response if available
+         if "response" in self._tensordict.keys():
+             response_tokens_list = self.get("response", as_list=True)
+             response_tokens_list = _prepare_tokens_for_decode(response_tokens_list)
+             response_texts = tokenizer.batch_decode(
+                 response_tokens_list, skip_special_tokens=skip_special_tokens
+             )
+             text_out.set("response", response_texts)
+
+         # Decode full if available
+         if "full" in self._tensordict.keys():
+             full_tokens_list = self.get("full", as_list=True)
+             full_tokens_list = _prepare_tokens_for_decode(full_tokens_list)
+             full_texts = tokenizer.batch_decode(
+                 full_tokens_list, skip_special_tokens=skip_special_tokens
+             )
+             text_out.set("full", full_texts)
+
+         return text_out
+
+     def to_history(
+         self,
+         tokenizer: AutoTokenizer,
+         chat_template_name: str | None = None,
+         skip_special_tokens: bool = False,
+     ) -> ChatHistory:
+         """Convert tokens to history by first decoding to text, then parsing.
+
+         Args:
+             tokenizer: The tokenizer to use for decoding and parsing.
+             chat_template_name: Optional chat template name for parsing.
+             skip_special_tokens: Whether to skip special tokens when decoding.
+
+         Returns:
+             A ChatHistory object with parsed conversation history.
+
+         Raises:
+             ValueError: If padded tokens are provided (not yet supported).
+         """
+         # First convert to text
+         text_obj = self.to_text(tokenizer, skip_special_tokens=skip_special_tokens)
+
+         # Then convert text to history
+         return text_obj.to_history(tokenizer, chat_template_name=chat_template_name)
+
+
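# --- Editor's illustrative sketch (not part of the packaged file) ---
# Round-tripping token ids back to text with ``Tokens.to_text``. Assumes a
# Hugging Face tokenizer; the checkpoint name below is only a placeholder.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
ids = tok("Hello, world!", return_tensors="pt")["input_ids"]  # shape (1, seq_len)
tokens = Tokens(prompt=ids, full=ids, padded=False, batch_size=(1,))
text = tokens.to_text(tok)  # Text object with decoded .prompt / .full fields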
+ class Masks(TensorClass["nocast"]):
+     """A Masks container.
+
+     Args:
+         all_attention_mask (torch.Tensor | None): The attention mask across all tokens. The attention mask represents
+             the tokens that are not masked and that the model can attend to.
+         all_assistant_mask (torch.Tensor | None): The assistant mask across all tokens, i.e. the tokens that
+             are produced by the assistant.
+             This is recovered from the `assistant_masks` output of :meth:`~torchrl.data.llm.History.apply_chat_template`,
+             if the chat template supports it.
+         padded (bool | None): Whether the masks are padded.
+
+     The masks always have the same shape as the `full` tensor in :class:`~torchrl.modules.llm.policies.common.Tokens`
+         and :class:`~torchrl.modules.llm.policies.common.LogProbs`.
+
+     """
+
+     all_attention_mask: torch.Tensor | None = None
+     all_assistant_mask: torch.Tensor | None = None
+     padded: bool | None = None
+
+     @classmethod
+     def default_spec(
+         cls,
+         shape=(-1,),
+         keys: list[Literal["all_attention_mask", "all_assistant_mask"]] | None = None,
+     ):
+         """A default spec to use in transforms / envs that return Masks objects."""
+         from torchrl.data import Composite, NonTensor
+
+         if keys is None:
+             keys = ["all_attention_mask", "all_assistant_mask"]
+
+         defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+         defaults["padded"] = NonTensor(shape=shape, example_data=False)
+
+         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+
+ class ChatHistory(TensorClass["nocast"]):
+     """A chat history container for managing conversation data in LLM environments.
+
+     This class serves as a structured container for chat history data, similar to how
+     :class:`~torchrl.modules.llm.policies.Text` and :class:`~torchrl.modules.llm.policies.Tokens`
+     are used for text and token data respectively.
+
+     **Recent Changes:**
+     - **Modular Design**: ChatHistory is now used consistently across LLM wrappers and environments
+       to represent conversation state in a structured way.
+     - **Integration with Wrappers**: Both vLLMWrapper and TransformersWrapper now use ChatHistory
+       objects when `input_mode="history"` is specified.
+     - **Environment Support**: ChatEnv and related environments use ChatHistory for state management.
+
+     Args:
+         prompt (History | None): The prompt history stack containing the conversation up to the current point.
+         response (History | None): The response history items (typically generated by the LLM).
+         full (History | None): The complete history across prompt and response.
+
+     Example:
+         >>> from torchrl.data.llm import History
+         >>> from torchrl.modules.llm.policies import ChatHistory
+         >>>
+         >>> # Create a conversation history
+         >>> history = History.from_chats([[
+         ...     {"role": "user", "content": "Hello"},
+         ...     {"role": "assistant", "content": "Hi there!"}
+         ... ]])
+         >>>
+         >>> # Create ChatHistory object for LLM wrapper input
+         >>> chat_history = ChatHistory(prompt=history)
+         >>>
+         >>> # Use with LLM wrapper
+         >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+         >>> print(result["history"].response)  # New response from LLM
+         >>> print(result["history"].full)  # Complete conversation
+
+     .. seealso::
+         :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
+         :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
+         :class:`~torchrl.data.llm.History`: The underlying History class for conversation data.
+     """
+
+     prompt: History | None = None
+     response: History | None = None
+     full: History | None = None
+
+     @classmethod
+     def default_spec(
+         cls,
+         shape=(-1,),
+         keys: list[Literal["prompt", "response", "full"]] | None = None,
+     ):
+         """A default spec to use in transforms / envs that return ChatHistory objects."""
+         from torchrl.data import Composite
+
+         if keys is None:
+             keys = ["prompt", "response", "full"]
+         return Composite(
+             {k: History.default_spec(shape=shape + (-1,)) for k in keys},
+             shape=shape[:-1],
+             data_cls=cls,
+             step_mdp_static=True,
+         )
+
+     def __post_init__(self):
+         # Check that all history objects have one more batch dimension than the ChatHistory object
+         if self.prompt is not None:
+             if getattr(self.prompt, "batch_dims", None) == self.batch_dims:
+                 warnings.warn(
+                     "Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                     f"got {self.prompt.batch_dims} and {self.batch_dims}. "
+                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                 )
+                 self.prompt = lazy_stack(
+                     [self.prompt], -1
+                 )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
+         if self.response is not None:
+             if getattr(self.response, "batch_dims", None) == self.batch_dims:
+                 warnings.warn(
+                     "Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                     f"got {self.response.batch_dims} and {self.batch_dims}. "
+                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                 )
+                 self.response = lazy_stack(
+                     [self.response], -1
+                 )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
+         if self.full is not None:
+             if getattr(self.full, "batch_dims", None) == self.batch_dims:
+                 warnings.warn(
+                     "Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                     f"got {self.full.batch_dims} and {self.batch_dims}. "
+                     "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                 )
+                 self.full = lazy_stack(
+                     [self.full], -1
+                 )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
+
+     def to_tokens(
+         self,
+         tokenizer: AutoTokenizer,
+         chat_template_name: str | None = None,
+         chat_template: str | None = None,
+     ) -> Tokens:
+         """Tokenize the conversation history into a :class:`Tokens` object.
+
+         Args:
+             tokenizer: The tokenizer to use for tokenization.
+             chat_template_name: Optional chat template name to use.
+             chat_template: Optional chat template string to use.
+
+         Returns:
+             A Tokens object with prompt, response, and full tokens.
+
+         Note:
+             - For prompt: uses add_generation_prompt=True
+             - For full: uses add_generation_prompt=False
+             - Response is computed by slicing full tokens after prompt length
+         """
+         from tensordict.utils import _zip_strict
+
+         tokenizer_kwargs = {}
+         if chat_template_name is not None:
+             tokenizer_kwargs["chat_template_name"] = chat_template_name
+         if chat_template is not None:
+             tokenizer_kwargs["chat_template"] = chat_template
+
+         # Create output structure
+         tokens_out = Tokens._from_tensordict(self._tensordict.empty())
+
+         # Process prompt if available
+         if self.prompt is not None:
+             prompt_tokens = self.prompt.apply_chat_template(
+                 tokenizer=tokenizer,
+                 return_dict=True,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 padding=False,
+                 **tokenizer_kwargs,
+             )
+             # Get input_ids using as_nested_tensor to handle different lengths
+             tokens_out._tensordict.set(
+                 "prompt", prompt_tokens.get("input_ids", as_list=True)
+             )
+
+         # Process full if available
+         if self.full is not None:
+             full_tokens = self.full.apply_chat_template(
+                 tokenizer=tokenizer,
+                 return_dict=True,
+                 add_generation_prompt=False,
+                 tokenize=True,
+                 padding=False,
+                 **tokenizer_kwargs,
+             )
+             # Get input_ids using as_nested_tensor to handle different lengths
+             tokens_out._tensordict.set(
+                 "full", full_tokens.get("input_ids", as_list=True)
+             )
+
+         # Compute response by slicing if both prompt and full are available
+         if self.prompt is not None and self.full is not None:
+             prompt_tokens_list = tokens_out.get("prompt", as_list=True)
+             full_tokens_list = tokens_out.get("full", as_list=True)
+             response_tokens_list = []
+
+             for prompt_tok, full_tok in _zip_strict(
+                 prompt_tokens_list, full_tokens_list
+             ):
+                 prompt_len = prompt_tok.shape[-1]
+                 response_tok = full_tok[..., prompt_len:]
+                 response_tokens_list.append(response_tok)
+
+             tokens_out.set("response", response_tokens_list)
+
+         # Process response directly if available (and full is not)
+         elif self.response is not None:
+             response_tokens = self.response.apply_chat_template(
+                 tokenizer=tokenizer,
+                 return_dict=True,
+                 add_generation_prompt=False,
+                 tokenize=True,
+                 padding=False,
+                 **tokenizer_kwargs,
+             )
+             # Get input_ids using as_nested_tensor to handle different lengths
+             tokens_out._tensordict.set(
+                 "response", response_tokens.get("input_ids", as_list=True)
+             )
+
+         tokens_out.padded = False
+         return tokens_out
+
+     def to_text(
+         self,
+         tokenizer: AutoTokenizer,
+         chat_template_name: str | None = None,
+         chat_template: str | None = None,
+     ) -> Text:
+         """Convert the conversation history into a :class:`Text` object.
+
+         Args:
+             tokenizer: The tokenizer to use for applying chat templates.
+             chat_template_name: Optional chat template name to use.
+             chat_template: Optional chat template string to use.
+
+         Returns:
+             A Text object with prompt, response, and full text.
+
+         Note:
+             - For prompt: uses add_generation_prompt=True
+             - For full: uses add_generation_prompt=False
+             - Response is computed by removing prompt prefix from full text
+         """
+         from tensordict.utils import _zip_strict
+
+         tokenizer_kwargs = {}
+         if chat_template_name is not None:
+             tokenizer_kwargs["chat_template_name"] = chat_template_name
+         if chat_template is not None:
+             tokenizer_kwargs["chat_template"] = chat_template
+
+         # Create output structure
+         text_out = Text._from_tensordict(self._tensordict.empty())
+
+         # Process prompt if available
+         if self.prompt is not None:
+             prompt_text = self.prompt.apply_chat_template(
+                 tokenizer=tokenizer,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 **tokenizer_kwargs,
+             )
+             text_out.set("prompt", prompt_text)
+
+         # Process full if available
+         if self.full is not None:
+             full_text = self.full.apply_chat_template(
+                 tokenizer=tokenizer,
+                 tokenize=False,
+                 add_generation_prompt=False,
+                 **tokenizer_kwargs,
+             )
+             text_out.set("full", full_text)
+
+         # Compute response by removing prompt prefix if both are available
+         if self.prompt is not None and self.full is not None:
+             prompt_texts_list = text_out.get("prompt", as_list=True)
+             full_texts_list = text_out.get("full", as_list=True)
+             response_texts_list = []
+
+             for prompt_txt, full_txt in _zip_strict(prompt_texts_list, full_texts_list):
+                 if full_txt.startswith(prompt_txt):
+                     response_txt = full_txt[len(prompt_txt) :]
+                 else:
+                     raise ValueError(
+                         f"Full text does not start with prompt text. "
+                         f"Prompt: {prompt_txt[:50]}..., Full: {full_txt[:50]}..."
+                     )
+                 response_texts_list.append(response_txt)
+
+             text_out.set("response", response_texts_list)
+
+         # Process response directly if available (and full is not)
+         elif self.response is not None:
+             response_text = self.response.apply_chat_template(
+                 tokenizer=tokenizer,
+                 tokenize=False,
+                 add_generation_prompt=False,
+                 **tokenizer_kwargs,
+             )
+             text_out.set("response", response_text)
+
+         return text_out
+
+
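# --- Editor's illustrative sketch (not part of the packaged file) ---
# Rendering a ChatHistory through the chat template, as described in
# ``ChatHistory.to_text`` / ``ChatHistory.to_tokens`` above. Assumes a Hugging
# Face chat tokenizer; the checkpoint name is only a placeholder.
from transformers import AutoTokenizer
from torchrl.data.llm import History

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
hist = History.from_chats([[{"role": "user", "content": "Hello"}]])
chat = ChatHistory(prompt=hist)
text = chat.to_text(tok)    # Text whose .prompt is rendered with add_generation_prompt=True
toks = chat.to_tokens(tok)  # Tokens with unpadded .prompt ids (padded=False)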
+ class LogProbs(TensorClass["nocast"]):
+     """A log-probability container.
+
+     Args:
+         prompt (torch.Tensor | None): The prompt log-probabilities.
+         response (torch.Tensor | None): The response log-probabilities.
+         assistant (torch.Tensor | None): The assistant log-probabilities.
+         full (torch.Tensor | None): The log-probabilities across prompt and response.
+         padded (bool | None): Whether the log-probabilities are padded.
+
+     Shapes:
+         - prompt: (batch_size, prompt_length). If padded, padded on the left.
+         - response: (batch_size, response_length). If padded, padded on the right.
+         - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
+         - padded: bool.
+
+     """
+
+     prompt: torch.Tensor | None = None
+     response: torch.Tensor | None = None
+     full: torch.Tensor | None = None
+     padded: bool | None = None
+
+     @classmethod
+     def default_spec(
+         cls,
+         shape=(-1,),
+         keys: list[Literal["prompt", "response", "full"]] | None = None,
+     ):
+         """A default spec to use in transforms / envs that return LogProbs objects."""
+         from torchrl.data import Composite, NonTensor
+
+         if keys is None:
+             keys = ["prompt", "response", "full"]
+
+         defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+         defaults["padded"] = NonTensor(shape=shape, example_data=False)
+
+         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+
+ class Text(TensorClass["nocast"]):
+     """A text container.
+
+     Args:
+         prompt (str | None): The prompt text.
+         response (str | None): The response text.
+         full (str | None): The text across prompt and response.
+     """
+
+     prompt: str | None = None
+     response: str | None = None
+     full: str | None = None
+
+     @classmethod
+     def default_spec(
+         cls,
+         shape=(-1,),
+         keys: list[Literal["prompt", "response", "full"]] | None = None,
+     ):
+         """A default spec to use in transforms / envs that return Text objects."""
+         from torchrl.data import Composite, NonTensor
+
+         if keys is None:
+             keys = ["prompt", "response", "full"]
+
+         defaults = {k: NonTensor(shape=shape, example_data="a string") for k in keys}
+
+         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+     def to_tokens(
+         self,
+         tokenizer: AutoTokenizer,
+         padding: bool = False,
+         truncation: bool = False,
+         return_tensors: str = "pt",
+     ) -> Tokens:
+         """Convert text to tokens using the tokenizer.
+
+         Args:
+             tokenizer: The tokenizer to use for encoding.
+             padding: Whether to pad the sequences.
+             truncation: Whether to truncate the sequences.
+             return_tensors: The format of the output tensors.
+
+         Returns:
+             A Tokens object with tokenized text.
+
+         Raises:
+             ValueError: If padding is requested (not yet supported).
+         """
+         if padding:
+             raise ValueError(
+                 "Padding is not yet supported for text to tokens conversion. "
+                 "Please use padding=False."
+             )
+
+         # When not padding, we can't use return_tensors because sequences have different lengths
+         # We'll get lists and convert them to tensors ourselves
+         actual_return_tensors = return_tensors if padding else None
+
+         # Create output structure
+         tokens_out = Tokens._from_tensordict(self._tensordict.empty())
+
+         # Tokenize prompt if available
+         if self.prompt is not None:
+             prompt_texts_list = self.prompt
+             prompt_tokens = tokenizer(
+                 prompt_texts_list,
+                 padding=padding,
+                 truncation=truncation,
+                 return_tensors=actual_return_tensors,
+             )
+             # Convert to list of tensors
+             input_ids = prompt_tokens["input_ids"]
+             if not isinstance(input_ids, list):
+                 input_ids = list(input_ids)
+             else:
+                 # Convert each list to tensor
+                 input_ids = [torch.tensor(ids) for ids in input_ids]
+             tokens_out.set("prompt", input_ids)
+
+         # Tokenize response if available
+         if self.response is not None:
+             response_texts_list = self.response
+             response_tokens = tokenizer(
+                 response_texts_list,
+                 padding=padding,
+                 truncation=truncation,
+                 return_tensors=actual_return_tensors,
+             )
+             # Convert to list of tensors
+             input_ids = response_tokens["input_ids"]
+             if not isinstance(input_ids, list):
+                 input_ids = list(input_ids)
+             else:
+                 # Convert each list to tensor
+                 input_ids = [torch.tensor(ids) for ids in input_ids]
+             tokens_out.set("response", input_ids)
+
+         # Tokenize full if available
+         if self.full is not None:
+             full_texts_list = self.full
+             full_tokens = tokenizer(
+                 full_texts_list,
+                 padding=padding,
+                 truncation=truncation,
+                 return_tensors=actual_return_tensors,
+             )
+             # Convert to list of tensors
+             input_ids = full_tokens["input_ids"]
+             if not isinstance(input_ids, list):
+                 input_ids = list(input_ids)
+             else:
+                 # Convert each list to tensor
+                 input_ids = [torch.tensor(ids) for ids in input_ids]
+             tokens_out.set("full", input_ids)
+
+         tokens_out.padded = padding
+         return tokens_out
+
+     def to_history(
+         self,
+         tokenizer: AutoTokenizer,
+         chat_template_name: str | None = None,
+     ) -> ChatHistory:
+         """Convert text to history by parsing the chat format.
+
+         Args:
+             tokenizer: The tokenizer to use for parsing.
+             chat_template_name: Optional chat template name for parsing.
+
+         Returns:
+             A ChatHistory object with parsed conversation history.
+         """
+         from torchrl.data.llm import History
+
+         # Create output structure
+         history_out = ChatHistory._from_tensordict(self._tensordict.empty())
+
+         # Parse prompt if available
+         if self.prompt is not None:
+             prompt_texts_list = self.prompt
+             prompt_histories_list = []
+             for prompt_text in prompt_texts_list:
+                 prompt_hist = History.from_text(
+                     prompt_text,
+                     chat_template_name=chat_template_name,
+                     tokenizer=tokenizer,
+                 )
+                 prompt_histories_list.append(prompt_hist)
+             history_out.set("prompt", lazy_stack(prompt_histories_list))
+
+         # Parse response if available
+         if self.response is not None:
+             response_texts_list = self.response
+             response_histories_list = []
+             for response_text in response_texts_list:
+                 response_hist = History.from_text(
+                     response_text,
+                     chat_template_name=chat_template_name,
+                     tokenizer=tokenizer,
+                 )
+                 response_histories_list.append(response_hist)
+             history_out.set("response", lazy_stack(response_histories_list))
+
+         # Parse full if available
+         if self.full is not None:
+             full_texts_list = self.full
+             full_histories_list = []
+             for full_text in full_texts_list:
+                 full_hist = History.from_text(
+                     full_text,
+                     chat_template_name=chat_template_name,
+                     tokenizer=tokenizer,
+                 )
+                 full_histories_list.append(full_hist)
+             history_out.set("full", lazy_stack(full_histories_list))
+
+         return history_out
+
+
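# --- Editor's illustrative sketch (not part of the packaged file) ---
# Tokenizing a Text container, as described in ``Text.to_tokens`` above.
# Assumes a Hugging Face tokenizer; padding=True is rejected by design.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
text = Text(prompt=["Hello"], full=["Hello world"], batch_size=(1,))
tokens = text.to_tokens(tok)  # one id tensor per entry, tokens.padded == False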
+ class LogProbDistribution(D.Distribution):
+     """A distribution that works directly with log-probabilities.
+
+     This is useful when we have pre-computed log-probabilities (e.g., from vLLM)
+     and want to compute log_prob() without having access to the original logits.
+     """
+
+     def __init__(self, log_probs: torch.Tensor, mask: torch.Tensor | None = None):
+         """Initialize with log-probabilities.
+
+         Args:
+             log_probs: Tensor of shape [batch, seq_len] containing log-probabilities
+             mask: Optional mask of shape [batch, seq_len] indicating valid positions
+         """
+         self.log_probs = log_probs
+         self.mask = mask
+         batch_shape = log_probs.shape[:-1] if log_probs.dim() > 1 else log_probs.shape
+         event_shape = log_probs.shape[-1:] if log_probs.dim() > 1 else torch.Size([])
+         super().__init__(batch_shape=batch_shape, event_shape=event_shape)
+
+     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+         """Compute log-probability for the given tokens.
+
+         Args:
+             value: Tensor of shape [batch, seq_len] containing token indices
+
+         Returns:
+             Tensor of shape [batch, seq_len] containing log-probabilities
+         """
+         # For log-prob distributions, we just return the pre-computed log-probs
+         # at the positions specified by the value tensor
+         if value.shape != self.log_probs.shape:
+             raise ValueError(
+                 f"Value shape {value.shape} must match log_probs shape {self.log_probs.shape}"
+             )
+
+         result = self.log_probs.clone()
+
+         # Apply mask if provided
+         if self.mask is not None:
+             result = torch.where(
+                 self.mask,
+                 result,
+                 torch.tensor(0.0, device=result.device, dtype=result.dtype),
+             )
+
+         return result
+
+     def sample(self, sample_shape: tuple | torch.Size | None = None) -> torch.Tensor:
+         """Sample from the distribution.
+
+         Note: This is not implemented for log-prob distributions since we don't have
+         the full probability distribution, only the log-probs for specific tokens.
+         """
+         raise NotImplementedError("Sampling not supported for LogProbDistribution")
+
+     def entropy(self) -> torch.Tensor:
+         """Compute entropy.
+
+         Note: This is not implemented for log-prob distributions since we don't have
+         the full probability distribution.
+         """
+         raise NotImplementedError("Entropy not supported for LogProbDistribution")
+
+
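# --- Editor's illustrative sketch (not part of the packaged file) ---
# LogProbDistribution simply returns the pre-computed per-token log-probs and
# zeroes out masked positions; sampling and entropy are unavailable by design.
lp = torch.tensor([[-0.1, -2.3, -0.7]])
mask = torch.tensor([[True, True, False]])
dist = LogProbDistribution(lp, mask)
token_ids = torch.tensor([[5, 17, 0]])
print(dist.log_prob(token_ids))  # tensor([[-0.1000, -2.3000,  0.0000]])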
+ class LLMWrapperBase(TensorDictModuleBase):
+     r"""An LLM wrapper base class.
+
+     This class provides a consistent interface for LLM wrappers with the following features:
+     - Support for different input modalities (history, text, tokens)
+     - Consistent output structure using TensorClass objects (Text, Tokens, Masks, LogProbs)
+     - Configurable generation and log-probability computation
+     - Standardized generation parameters across different backends
+
+     Args:
+         model: The underlying model to wrap.
+
+     Keyword Args:
+         tokenizer: The tokenizer to use for encoding and decoding text.
+         input_mode: The input modality to use. Must be one of "history", "text", or "tokens".
+         input_key: The key for the input data. If None, defaults to the input_mode name.
+         attention_mask_key: The key for attention masks (used in "tokens" mode).
+         generate: Whether to enable text generation.
+         generate_kwargs: Additional arguments to pass to the model's generate method.
+
+             **Common Parameters (cross-backend compatible):**
+
+             * **max_new_tokens** (int): Maximum number of new tokens to generate
+             * **num_return_sequences** (int): Number of sequences to return
+             * **temperature** (float): Sampling temperature (0.0 = deterministic, higher = more random)
+             * **top_p** (float): Nucleus sampling parameter (0.0-1.0)
+             * **top_k** (int): Top-k sampling parameter
+             * **repetition_penalty** (float): Penalty for repeating tokens
+             * **do_sample** (bool): Whether to use sampling vs greedy decoding
+             * **num_beams** (int): Number of beams for beam search
+             * **length_penalty** (float): Penalty for sequence length
+             * **early_stopping** (bool): Whether to stop early in beam search
+             * **stop_sequences** (list): Sequences that stop generation
+             * **skip_special_tokens** (bool): Whether to skip special tokens in output
+             * **logprobs** (bool): Whether to return log probabilities
+
+             **Parameter Conflict Resolution:**
+
+             When both legacy (backend-specific) and standardized parameter names are provided,
+             the legacy names silently prevail. This ensures backward compatibility with existing code.
+
+             * If both ``max_tokens`` and ``max_new_tokens`` are passed, ``max_tokens`` wins
+             * If both ``n`` and ``num_return_sequences`` are passed, ``n`` wins
+
+             This behavior allows existing code to continue working without modification.
+
+             **Parameter Validation:**
+
+             The following validations are performed:
+
+             * Temperature must be non-negative
+             * top_p must be between 0 and 1
+             * top_k must be positive
+             * repetition_penalty must be positive
+             * When do_sample=False, temperature must be 0 for greedy decoding
+
+         tokenizer_kwargs: Additional arguments to pass to the tokenizer.
+         pad_output: Whether to pad the output sequences to a uniform length.
+         pad_model_input: Whether to pad the model input sequences to a uniform length.
+             May not be supported by all models.
+         inplace: Determines how the module should handle in-place operations.
+         device: The device to use for computation.
+         layout: The layout to use for the output tensors when pad_output=False.
+         num_samples: The number of samples to generate.
+         log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+         text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+         tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+         masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+         batching (bool | None, optional): Whether to enable batching. See `Batching`_ below for more details.
+         min_batch_size (int | None, optional): The minimum batch size to use for batching. See `Batching`_ below for more details.
+         max_batch_size (int | None, optional): The maximum batch size to use for batching. See `Batching`_ below for more details.
+         batching_timeout (float, optional): The timeout for batching. See `Batching`_ below for more details.
+
+     .. _Batching:
+
+     **Batching**
+
+     Batching is a feature that allows the module to process multiple inputs in a single call.
+     It is designed to work in a multi-threaded environment.
+     To enable batching, it suffices to set `batching=True`, which will set `min_batch_size` to 1 if not provided.
+     If you want to set a different value for `min_batch_size` or `max_batch_size` for fine-grained control,
+     you can set `batching=True` and then set `min_batch_size` or `max_batch_size` to a value greater than or equal to 1.
+     The way batching works is as follows:
+     - If `min_batch_size` is not provided but `max_batch_size` is, `min_batch_size` is set to 1.
+     - If `max_batch_size` is not provided but `min_batch_size` is, `max_batch_size` is set to the number of inputs in the queue.
+     - When the model is called, a check is performed to see if the number of inputs in the queue is greater than or equal to `min_batch_size`.
+       If it is, the batch is processed immediately, while waiting for the previous batch to be processed if the model is busy.
+       Otherwise, the input is added to the queue and the function waits for the batch to be completed.
+       While waiting for the batch to be completed, a timeout is set to `batching_timeout` seconds such that if the batch is not
+       completed after `batching_timeout` seconds, the remaining items to process are processed as is and the function returns after
+       at most `batching_timeout` seconds (plus the time to finish processing the previous and current batch).
+
+     Attributes:
+         collector: The collector associated with the module, if it exists.
+
+     .. seealso::
+         - :class:`~torchrl.modules.llm.policies.TransformersWrapper`
+         - :class:`~torchrl.modules.llm.policies.vLLMWrapper`
+     """
+
+     generate: bool
+     pad_output: bool
+     text_key: NestedKey
+     tokens_key: NestedKey
+     masks_key: NestedKey
+     log_probs_key: NestedKey
+     in_keys: list[NestedKey]
+     out_keys: list[NestedKey]
+     inplace: bool
+     device: torch.device | None
+     layout: torch.layout | None
+     num_samples: int | None
+     _min_batch_size: int | None
+     _max_batch_size: int | None
+     _batching_lock: threading.Lock | None
+     _batching_timeout: float | None
+
+     # Common generation parameters that work across both vLLM and Transformers
+     COMMON_GENERATION_PARAMS = {
+         "max_new_tokens",
+         "num_return_sequences",
+         "temperature",
+         "top_p",
+         "top_k",
+         "repetition_penalty",
+         "do_sample",
+         "num_beams",
+         "length_penalty",
+         "early_stopping",
+         "stop_sequences",
+         "skip_special_tokens",
+         "logprobs",
+     }
+
+     @overload
+     def __init__(
+         self,
+         model: Any | str,
+         *,
+         tokenizer: callable | str | None = None,  # type: ignore
+         input_mode: str = "history",
+         input_key: NestedKey | None = None,
+         attention_mask_key: str = "attention_mask",
+         generate: bool = True,
+         generate_kwargs: dict | None = None,
+         tokenizer_kwargs: dict | None = None,
+         pad_output: bool = False,
+         inplace: Literal[True, False, "empty"] | None = None,
+         device: torch.device | None = None,
+         layout: torch.layout | None = None,
+         num_samples: int | None = None,
+         chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+         chat_template: str | None = None,
+         return_log_probs: bool | None = None,
+         history_key: NestedKey | None = "history",
+         text_key: NestedKey | None = "text",
+         tokens_key: NestedKey | None = "tokens",
+         masks_key: NestedKey | None = "masks",
+         log_probs_key: NestedKey | None = "log_probs",
+         batching: bool | None = None,
+         min_batch_size: int | None = None,
+         max_batch_size: int | None = None,
+         batching_timeout: float = 10.0,
+     ):
+         ...
+
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+
952
+ @classmethod
953
+ def _standardize_generate_kwargs(cls, generate_kwargs: dict | None) -> dict:
954
+ """Standardize generation parameters to use common names across wrappers.
955
+
956
+ This method converts wrapper-specific parameter names to common names:
957
+
958
+ * vLLM's ``max_tokens`` -> ``max_new_tokens``
959
+ * vLLM's ``n`` -> ``num_return_sequences``
960
+
961
+ **Parameter Conflict Resolution:**
962
+
963
+ When both legacy (backend-specific) and standardized parameter names are provided,
964
+ the legacy names silently prevail. This ensures backward compatibility with existing code.
965
+
966
+ Args:
967
+ generate_kwargs: The generation parameters to standardize
968
+
969
+ Returns:
970
+ Standardized generation parameters
971
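+
+ Example (illustrative; ``wrapper_cls`` stands for any concrete wrapper class):
+ >>> wrapper_cls._standardize_generate_kwargs({"max_tokens": 64, "n": 2})
+ {'max_new_tokens': 64, 'num_return_sequences': 2}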
+ """
972
+ if generate_kwargs is None:
973
+ return {}
974
+
975
+ standardized = dict(generate_kwargs)
976
+
977
+ # Convert vLLM parameter names to common names
978
+ # Legacy names prevail in conflicts (backward compatibility)
979
+ if "max_tokens" in standardized:
980
+ if "max_new_tokens" in standardized:
981
+ # Legacy name wins - remove the standardized name
982
+ standardized.pop("max_new_tokens")
983
+ standardized["max_new_tokens"] = standardized.pop("max_tokens")
984
+
985
+ if "n" in standardized:
986
+ if "num_return_sequences" in standardized:
987
+ # Legacy name wins - remove the standardized name
988
+ standardized.pop("num_return_sequences")
989
+ standardized["num_return_sequences"] = standardized.pop("n")
990
+
991
+ # Validate parameter combinations
992
+ cls._validate_parameter_combinations(standardized)
993
+
994
+ return standardized
995
+
996
+ @classmethod
997
+ def _validate_parameter_combinations(cls, generate_kwargs: dict) -> None:
998
+ """Validate that parameter combinations make sense.
999
+
1000
+ This method performs the following validations:
1001
+
1002
+ * Temperature must be non-negative
1003
+ * top_p must be between 0 and 1
1004
+ * top_k must be positive
1005
+ * repetition_penalty must be positive
1006
+ * When do_sample=False, temperature must be 0 for greedy decoding
1007
+
1008
+ Args:
1009
+ generate_kwargs: The generation parameters to validate
1010
+
1011
+ Raises:
1012
+ ValueError: If parameter combinations are invalid
1013
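+
+ Example (illustrative; ``wrapper_cls`` stands for any concrete wrapper class):
+ >>> wrapper_cls._validate_parameter_combinations({"do_sample": False, "temperature": 0.7})
+ Traceback (most recent call last):
+ ...
+ ValueError: When do_sample=False (greedy decoding), temperature must be 0. Got temperature=0.7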
+ """
1014
+ # Check for conflicting sampling parameters
1015
+ if generate_kwargs.get("do_sample") is False:
1016
+ # If do_sample=False, temperature should be 0 for greedy decoding
1017
+ if generate_kwargs.get("temperature", 0) != 0:
1018
+ raise ValueError(
1019
+ "When do_sample=False (greedy decoding), temperature must be 0. "
1020
+ f"Got temperature={generate_kwargs.get('temperature')}"
1021
+ )
1022
+
1023
+ # Check for valid temperature range
1024
+ temperature = generate_kwargs.get("temperature")
1025
+ if temperature is not None and temperature < 0:
1026
+ raise ValueError(f"Temperature must be non-negative, got {temperature}")
1027
+
1028
+ # Check for valid top_p range
1029
+ top_p = generate_kwargs.get("top_p")
1030
+ if top_p is not None and not (0 <= top_p <= 1):
1031
+ raise ValueError(f"top_p must be between 0 and 1, got {top_p}")
1032
+
1033
+ # Check for valid top_k
1034
+ top_k = generate_kwargs.get("top_k")
1035
+ if top_k is not None and top_k <= 0:
1036
+ raise ValueError(f"top_k must be positive, got {top_k}")
1037
+
1038
+ # Check for valid repetition_penalty
1039
+ repetition_penalty = generate_kwargs.get("repetition_penalty")
1040
+ if repetition_penalty is not None and repetition_penalty <= 0:
1041
+ raise ValueError(
1042
+ f"repetition_penalty must be positive, got {repetition_penalty}"
1043
+ )
1044
+
1045
+ @classmethod
1046
+ def _get_wrapper_specific_kwargs(
1047
+ cls, generate_kwargs: dict, wrapper_type: str
1048
+ ) -> dict:
1049
+ """Extract wrapper-specific generation parameters.
1050
+
1051
+ Args:
1052
+ generate_kwargs: The generation parameters
1053
+ wrapper_type: Either 'vllm' or 'transformers'
1054
+
1055
+ Returns:
1056
+ Wrapper-specific parameters
1057
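+
+ Example (illustrative; ``wrapper_cls`` stands for any concrete wrapper class):
+ >>> kwargs = {"temperature": 0.7, "presence_penalty": 0.2, "pad_token_id": 0}
+ >>> wrapper_cls._get_wrapper_specific_kwargs(kwargs, "vllm")
+ {'presence_penalty': 0.2}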
+ """
1058
+ if generate_kwargs is None:
1059
+ return {}
1060
+
1061
+ if wrapper_type == "vllm":
1062
+ # vLLM-specific parameters
1063
+ vllm_specific = {
1064
+ "presence_penalty",
1065
+ "frequency_penalty",
1066
+ "ignore_eos",
1067
+ "prompt_logprobs",
1068
+ "detokenize",
1069
+ "include_stop_str_in_output",
1070
+ "spaces_between_special_tokens",
1071
+ "sampling_type",
1072
+ "temperature_last",
1073
+ "top_p_last",
1074
+ "top_k_last",
1075
+ }
1076
+ return {k: v for k, v in generate_kwargs.items() if k in vllm_specific}
1077
+
1078
+ elif wrapper_type == "transformers":
1079
+ # Transformers-specific parameters
1080
+ transformers_specific = {
1081
+ "pad_token_id",
1082
+ "eos_token_id",
1083
+ "bad_words_ids",
1084
+ "force_words_ids",
1085
+ "no_repeat_ngram_size",
1086
+ "encoder_repetition_penalty",
1087
+ "num_beam_groups",
1088
+ "diversity_penalty",
1089
+ "output_scores",
1090
+ "return_dict_in_generate",
1091
+ }
1092
+ return {
1093
+ k: v for k, v in generate_kwargs.items() if k in transformers_specific
1094
+ }
1095
+
1096
+ return {}
1097
+
1098
+ @property
1099
+ def batching(self) -> bool:
1100
+ """Whether batching is enabled."""
1101
+ return self._min_batch_size is not None or self._max_batch_size is not None
1102
+
1103
+ def get_new_version(self, **kwargs):
1104
+ """Returns a new version of the module with altered parameters.
1105
+
1106
+ For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
1107
+ This is especially useful when one wants to avoid re-initializing the module with a new set of parameters: for instance,
1108
+ the same model parameters can be reused to gather log-probs instead of generating text.
1109
+
1110
+ Positional arguments are not supported.
1111
+
1112
+ See the class constructor for more details about the parameters.
1113
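+
+ Example (illustrative sketch; ``policy`` is an already-constructed wrapper):
+ >>> # reuse the same underlying model to compute log-probs instead of generating text
+ >>> lp_policy = policy.get_new_version(generate=False, return_log_probs=True)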
+ """
1114
+ raise NotImplementedError
1115
+
1116
+ _collector: weakref.ReferenceType[
1117
+ LLMCollector # noqa: F821 # type: ignore
1118
+ ] | None = None
1119
+
1120
+ def register_collector(self, collector: LLMCollector): # noqa: F821 # type: ignore
1121
+ """Registers a weak reference to the container collector.
1122
+
1123
+ This is automatically called by the :class:`~torchrl.collectors.llm.LLMCollector` class.
1124
+ """
1125
+ self._collector = weakref.ref(collector)
1126
+
1127
+ @property
1128
+ def collector(self) -> LLMCollector | None: # noqa: F821 # type: ignore
1129
+ """Returns the collector associated with the module, if it exists."""
1130
+ return self._collector() if self._collector is not None else None
1131
+
1132
+ def get_dist(
1133
+ self,
1134
+ tensordict: TensorDictBase,
1135
+ tensordict_out: TensorDictBase | None = None,
1136
+ logits_key: NestedKey = "logits",
1137
+ mask_key: NestedKey | None = None,
1138
+ as_padded_tensor: bool | None = None,
1139
+ as_nested_tensor: bool | None = None,
1140
+ padding_value: float | None = None,
1141
+ padding_side: str = "left",
1142
+ layout: torch.layout | None = None,
1143
+ **kwargs,
1144
+ ) -> D.Distribution:
1145
+ """Get distribution from logits/log-probs with optional masking.
1146
+
1147
+ Args:
1148
+ tensordict: Input tensordict
1149
+ tensordict_out: Output tensordict (optional)
1150
+ logits_key: Key for logits/log-probs
1151
+ mask_key: Key for mask (optional).
1152
+ as_padded_tensor: Whether to return a padded tensor. Defaults to True unless as_nested_tensor is set to True.
1153
+ as_nested_tensor: Whether to return a nested tensor. Default is False.
1154
+ padding_value: Value for padding. Default is 0.0 for logits and False for masks.
1155
+ padding_side: Side for padding. Default is left by convention.
1156
+ layout: Tensor layout
1157
+ **kwargs: Additional arguments
1158
+
1159
+ Returns:
1160
+ Distribution (Categorical or LLMMaskedCategorical)
1161
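+
+ Example (illustrative sketch; ``policy`` is a wrapper built with generate=False):
+ >>> dist = policy.get_dist(td, mask_key=("masks", "all_attention_mask"))
+ >>> log_probs = dist.log_prob(tokens)  # ``tokens`` are the padded full token ids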
+ """
1162
+ if self.generate:
1163
+ raise NotImplementedError(
1164
+ "get_dist is not implemented for generate=True. "
1165
+ "You can create a new version of this wrapper using the `get_new_version` method."
1166
+ )
1167
+
1168
+ td_out = self.forward(tensordict.copy(), logits_only=True)
1169
+
1170
+ # Get logits/log-probs
1171
+ if as_padded_tensor is None:
1172
+ as_padded_tensor = as_nested_tensor is not True
1173
+ if padding_value is None:
1174
+ padding_value = 0.0
1175
+ if as_nested_tensor is None:
1176
+ as_nested_tensor = False
1177
+
1178
+ logits = td_out.get(
1179
+ logits_key,
1180
+ as_padded_tensor=as_padded_tensor,
1181
+ as_nested_tensor=as_nested_tensor,
1182
+ padding_value=padding_value,
1183
+ padding_side=padding_side,
1184
+ layout=layout,
1185
+ )
1186
+
1187
+ # Get mask if provided
1188
+ mask = None
1189
+ if mask_key is not None:
1190
+ mask = td_out.get(
1191
+ mask_key,
1192
+ as_padded_tensor=as_padded_tensor,
1193
+ as_nested_tensor=as_nested_tensor,
1194
+ padding_value=False,
1195
+ padding_side=padding_side,
1196
+ layout=layout,
1197
+ )
1198
+ elif as_padded_tensor:
1199
+ # Default mask for padded tensors
1200
+ mask = logits != padding_value
1201
+
1202
+ if mask is not None:
1203
+ dist = LLMMaskedCategorical(
1204
+ logits=logits,
1205
+ mask=mask,
1206
+ )
1207
+ if not dist._position_level_masking:
1208
+ raise ValueError(
1209
+ "Mask is not a position-level mask. "
1210
+ "This is likely because the mask is not a position-level mask."
1211
+ )
1212
+ return dist
1213
+ return Categorical(logits)
1214
+
1215
+ def _get_dist_with_prompt_mask(
1216
+ self,
1217
+ tensordict: TensorDictBase,
1218
+ tokens_key: NestedKey = ("tokens", "prompt"),
1219
+ logits_key: NestedKey = "logits",
1220
+ # TODO: add a prompt_mask and response_mask in Masks
1221
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
1222
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
1223
+ padding_side: str = "left",
1224
+ **kwargs,
1225
+ ) -> D.Distribution:
1226
+ """Get distribution masked to only include response tokens (exclude prompt).
1227
+
1228
+ This is suitable for single-turn scenarios where we want to compute loss
1229
+ only on the generated response, not the input prompt.
1230
+
1231
+ Note: If prompt tokens are not available (e.g., when using history input),
1232
+ this method falls back to using the assistant mask.
1233
+
1234
+ Padding side is left by convention.
1235
+
1236
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1237
+ """
1238
+ if self.generate:
1239
+ raise NotImplementedError(
1240
+ "get_dist_with_prompt_mask is not implemented for generate=True. "
1241
+ "You can create a new version of this wrapper using the `get_new_version` method."
1242
+ )
1243
+ td_out = self.forward(tensordict.copy(), logits_only=True)
1244
+
1245
+ # Try to get prompt tokens first
1246
+ if self.pad_output:
1247
+ prompt_tokens = tensordict.get(
1248
+ tokens_key,
1249
+ as_padded_tensor=True,
1250
+ padding_value=-100,
1251
+ padding_side=padding_side,
1252
+ )
1253
+ logits = td_out.get(
1254
+ logits_key,
1255
+ as_padded_tensor=True,
1256
+ padding_value=0.0,
1257
+ padding_side=padding_side,
1258
+ )
1259
+ attention_mask = tensordict.get(
1260
+ attention_mask_key,
1261
+ as_padded_tensor=True,
1262
+ padding_value=False,
1263
+ padding_side=padding_side,
1264
+ )
1265
+ assistant_mask = tensordict.get(
1266
+ assistant_mask_key,
1267
+ as_padded_tensor=True,
1268
+ padding_value=False,
1269
+ padding_side=padding_side,
1270
+ )
1271
+ else:
1272
+ prompt_tokens = tensordict.get(tokens_key, as_list=True)
1273
+ logits = td_out.get(logits_key, as_list=True)
1274
+ attention_mask = td_out.get(attention_mask_key, as_list=True)
1275
+ assistant_mask = td_out.get(assistant_mask_key, as_list=True)
1276
+
1277
+ if prompt_tokens is None:
1278
+ if assistant_mask is None:
1279
+ raise ValueError(
1280
+ f"Assistant mask not found in tensordict at key {assistant_mask_key} (keys: {td_out.keys()})"
1281
+ )
1282
+ if self.pad_output:
1283
+ response_mask = assistant_mask.clone()
1284
+ else:
1285
+ response_mask = [am.clone() for am in assistant_mask]
1286
+ else:
1287
+ if self.pad_output:
1288
+ response_mask = attention_mask.clone()
1289
+ response_mask[..., : prompt_tokens.shape[-1]] = False
1290
+ else:
1291
+ response_mask = []
1292
+ for am, p in _zip_strict(attention_mask, prompt_tokens):
1293
+ am = am.clone()
1294
+ am[..., : p.size(-1)] = False
1295
+ response_mask.append(am)
1296
+
1297
+ if logits is None:
1298
+ raise ValueError(
1299
+ f"Logits not found in tensordict at key {logits_key} (keys: {td_out.keys()})"
1300
+ )
1301
+
1302
+ # Pad the response mask and logits to a common shape when outputs are not padded
1303
+ if not self.pad_output:
1304
+ # Check that the length of each mask matches the length of the corresponding logits
1305
+ torchrl_logger.info(f"Response mask: {response_mask}")
1306
+ torchrl_logger.info(f"Logits: {logits}")
1307
+ for m, lg in _zip_strict(response_mask, logits):
1308
+ if m.shape[-1] != lg.shape[-2]:
1309
+ raise ValueError(
1310
+ f"Mask and logits have different lengths: {m.shape[-1]} != {lg.shape[-2]}.\n"
1311
+ f"All the logits shapes: {[lg.shape for lg in logits]}, all the mask shapes: {[m.shape for m in response_mask]}"
1312
+ )
1313
+ logits = pad_sequence(
1314
+ logits, batch_first=True, padding_value=0.0, padding_side=padding_side
1315
+ )
1316
+ response_mask = pad_sequence(
1317
+ response_mask,
1318
+ batch_first=True,
1319
+ padding_value=False,
1320
+ padding_side=padding_side,
1321
+ )
1322
+
1323
+ dist = LLMMaskedCategorical(
1324
+ logits=logits,
1325
+ mask=response_mask.bool(),
1326
+ )
1327
+ if not dist._position_level_masking:
1328
+ raise ValueError(
1329
+ "Mask is not a position-level mask. "
1330
+ "This is likely because the mask is not a position-level mask."
1331
+ )
1332
+ return dist
1333
+
1334
+ def _get_dist_with_assistant_mask(
1335
+ self,
1336
+ tensordict: TensorDictBase,
1337
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
1338
+ logits_key: NestedKey = "logits",
1339
+ padding_side: str = "left",
1340
+ **kwargs,
1341
+ ) -> D.Distribution:
1342
+ """Get distribution masked to only include assistant tokens.
1343
+
1344
+ This is suitable for multi-turn scenarios where we want to compute loss
1345
+ only on assistant-generated tokens across the entire conversation.
1346
+
1347
+ Padding side is left by convention.
1348
+
1349
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1350
+ """
1351
+ if self.generate:
1352
+ raise NotImplementedError(
1353
+ "get_dist_with_assistant_mask is not implemented for generate=True. "
1354
+ "You can create a new version of this wrapper using the `get_new_version` method."
1355
+ )
1356
+ td_out = self.forward(tensordict.copy(), logits_only=True)
1357
+ # Update the tokens key to reflect the tokenized history when querying the log-probs
1358
+ tensordict.update(
1359
+ td_out,
1360
+ keys_to_update=[
1361
+ ("tokens", "full"),
1362
+ ],
1363
+ )
1364
+
1365
+ if self.pad_output:
1366
+ logits = td_out.get(logits_key)
1367
+ assistant_mask = td_out.get(assistant_mask_key)
1368
+ else:
1369
+ logits = td_out.get(
1370
+ logits_key,
1371
+ as_padded_tensor=True,
1372
+ padding_value=0.0,
1373
+ padding_side=padding_side,
1374
+ )
1375
+ assistant_mask = td_out.get(
1376
+ assistant_mask_key,
1377
+ as_padded_tensor=True,
1378
+ padding_value=False,
1379
+ padding_side=padding_side,
1380
+ )
1381
+ if logits is None:
1382
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
1383
+ if assistant_mask is None:
1384
+ if self.input_mode != "history":
1385
+ post_msg = "This is likely because the input_mode is not 'history'."
1386
+ else:
1387
+ post_msg = ""
1388
+ raise ValueError(
1389
+ f"Assistant mask not found in tensordict at key {assistant_mask_key}. {post_msg}"
1390
+ )
1391
+
1392
+ dist = LLMMaskedCategorical(
1393
+ logits=logits,
1394
+ mask=assistant_mask,
1395
+ )
1396
+ if not dist._position_level_masking:
1397
+ raise ValueError(
1398
+ "Assistant mask is not a position-level mask. "
1399
+ "This is likely because the assistant mask is not a position-level mask."
1400
+ )
1401
+ return dist
1402
+
1403
+ def _get_dist_with_attention_mask(
1404
+ self,
1405
+ tensordict: TensorDictBase,
1406
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
1407
+ logits_key: NestedKey = "logits",
1408
+ padding_side: str = "left",
1409
+ **kwargs,
1410
+ ) -> D.Distribution:
1411
+ """Get distribution masked using attention mask.
1412
+
1413
+ This is suitable for generic scenarios where we want to compute loss
1414
+ on all valid tokens (non-padding tokens).
1415
+
1416
+ Padding side is left by convention.
1417
+
1418
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1419
+ """
1420
+ if self.generate:
1421
+ raise NotImplementedError(
1422
+ "get_dist_with_attention_mask is not implemented for generate=True. "
1423
+ "You can create a new version of this wrapper using the `get_new_version` method."
1424
+ )
1425
+ td_out = self.forward(tensordict.copy(), logits_only=True)
1426
+ if self.pad_output:
1427
+ logits = td_out.get(logits_key)
1428
+ attention_mask = td_out.get(attention_mask_key)
1429
+ else:
1430
+ logits = td_out.get(
1431
+ logits_key,
1432
+ as_padded_tensor=True,
1433
+ padding_value=0.0,
1434
+ padding_side=padding_side,
1435
+ )
1436
+ attention_mask = td_out.get(
1437
+ attention_mask_key,
1438
+ as_padded_tensor=True,
1439
+ padding_value=False,
1440
+ padding_side=padding_side,
1441
+ )
1442
+
1443
+ if logits is None:
1444
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
1445
+ if attention_mask is None:
1446
+ raise ValueError(
1447
+ f"Attention mask not found in tensordict at key {attention_mask_key}"
1448
+ )
1449
+
1450
+ dist = LLMMaskedCategorical(
1451
+ logits=logits,
1452
+ mask=attention_mask,
1453
+ )
1454
+ if not dist._position_level_masking:
1455
+ raise ValueError(
1456
+ "Attention mask is not a position-level mask. "
1457
+ "This is likely because the attention mask is not a position-level mask."
1458
+ )
1459
+ return dist
1460
+
1461
+ def _get_dist_with_custom_mask(
1462
+ self,
1463
+ tensordict: TensorDictBase,
1464
+ mask: torch.Tensor,
1465
+ logits_key: NestedKey = "logits",
1466
+ padding_side: str = "left",
1467
+ **kwargs,
1468
+ ) -> D.Distribution:
1469
+ """Get distribution with custom mask.
1470
+
1471
+ This allows for completely custom masking logic.
1472
+
1473
+ Padding side is left by convention.
1474
+
1475
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1476
+ """
1477
+ if self.generate:
1478
+ raise NotImplementedError(
1479
+ "get_dist_with_custom_mask is not implemented for generate=True. "
1480
+ "You can create a new version of this wrapper using the `get_new_version` method."
1481
+ )
1482
+ td_out = self.forward(tensordict.copy(), logits_only=True)
1483
+ if self.pad_output:
1484
+ logits = td_out.get(logits_key)
1485
+ else:
1486
+ logits = td_out.get(
1487
+ logits_key,
1488
+ as_padded_tensor=True,
1489
+ padding_value=0.0,
1490
+ padding_side=padding_side,
1491
+ )
1492
+
1493
+ if logits is None:
1494
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
1495
+
1496
+ dist = LLMMaskedCategorical(
1497
+ logits=logits,
1498
+ mask=mask,
1499
+ )
1500
+ if not dist._position_level_masking:
1501
+ raise ValueError(
1502
+ "Custom mask is not a position-level mask. "
1503
+ "This is likely because the custom mask is not a position-level mask."
1504
+ )
1505
+ return dist
1506
+
1507
+ # Convenience methods for common LLM training scenarios
1508
+ def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
1509
+ """Get distribution suitable for SFT loss (response tokens only).
1510
+
1511
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1512
+ """
1513
+ return self._get_dist_with_prompt_mask(tensordict, **kwargs)
1514
+
1515
+ def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
1516
+ """Get distribution suitable for RLHF loss (assistant tokens only).
1517
+
1518
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1519
+ """
1520
+ return self._get_dist_with_assistant_mask(tensordict, **kwargs)
1521
+
1522
+ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
1523
+ """Get distribution suitable for generic losses (all tokens).
1524
+
1525
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
1526
+ """
1527
+ return self._get_dist_with_attention_mask(tensordict, **kwargs)
1528
+
1529
+ def forward(
1530
+ self,
1531
+ tensordict: TensorDictBase,
1532
+ *,
1533
+ tensordict_out: TensorDictBase | None = None,
1534
+ logits_only: bool = False,
1535
+ **kwargs,
1536
+ ) -> TensorDictBase: # noqa: D417
1537
+ """Forward pass for the LLM policy.
1538
+
1539
+ Args:
1540
+ tensordict (TensorDictBase): The input tensordict.
1541
+
1542
+ Keyword Args:
1543
+ tensordict_out (TensorDictBase | None): The output tensordict.
1544
+ logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
1545
+ """
1546
+ raise NotImplementedError
1547
+
1548
+ def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
1549
+ """Check that a value is a padded tensor."""
1550
+ if not isinstance(val, torch.Tensor):
1551
+ # Informative error instead of a redundant assert followed by a bare ValueError
1552
+ raise ValueError(
1553
+ f"Expected a padded tensor (torch.Tensor), got {type(val)}"
1554
+ )
1555
+ return val
1556
+
1557
+ def _check_not_padded(
1558
+ self, val: list[torch.Tensor] | torch.Tensor
1559
+ ) -> list[torch.Tensor] | torch.Tensor:
1560
+ """Check that a value is not a padded tensor (i.e., a list of tensors)."""
1561
+ if isinstance(val, torch.Tensor):
1562
+ raise ValueError("Expected a list of tensors - not padded, got a tensor")
1563
+ return val
1564
+
1565
+ @property
1566
+ def log_prob_keys(self) -> list[NestedKey]:
1567
+ return getattr(self, "_log_prob_keys", ["log_probs"])
1568
+
1569
+ @log_prob_keys.setter
1570
+ def log_prob_keys(self, value: list[NestedKey]):
1571
+ self._log_prob_keys = value
1572
+
1573
+ @property
1574
+ def dist_params_keys(self) -> list[NestedKey]:
1575
+ raise NotImplementedError
1576
+
1577
+ @property
1578
+ def dist_sample_keys(self) -> list[NestedKey]:
1579
+ return ["tokens_response"]
1580
+
1581
+ def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
1582
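+ """Return the log-probabilities of the response tokens contained in ``data``.
+
+ Only available when generate=False; calling it on a generating wrapper raises a RuntimeError.
+ """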
+ if not self.generate:
1583
+ data = self(data)
1584
+ return data.get((self.log_prob_key, "response"), **get_kwargs)
1585
+ raise RuntimeError("log_prob not callable when generate=True.")
1586
+
1587
+ def cleanup_batching(self, *, flush: bool = False) -> None:
1588
+ """Reset the internal batching state.
1589
+
1590
+ Args:
1591
+ flush (bool, default False):
1592
+ • False → cancel / fail every still-pending Future.
1593
+ • True → try to run one last forward pass with whatever is left in
1594
+ `_batch_queue`, so callers receive real results instead of an
1595
+ exception.
1596
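+
+ Example (illustrative sketch):
+ >>> policy.cleanup_batching(flush=True)  # run one last pass over whatever is still queued
+ >>> policy.cleanup_batching()  # or: cancel / fail every still-pending Future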
+ """
1597
+ # ── 0. Fast-exit if batching was never enabled ──────────────────────────────
1598
+ if not hasattr(self, "_batch_queue"):
1599
+ return
1600
+
1601
+ # ── 1. Enter the same lock used by the decorator to avoid races ────────────
1602
+ lock = getattr(self, "_batching_lock", None) # may be None
1603
+ with (lock or nullcontext()):
1604
+ # ── 2. Resolve outstanding Futures ───────────────────────────────────
1605
+ if flush and self._batch_queue:
1606
+ try:
1607
+ # one last forward pass
1608
+ results = self(
1609
+ lazy_stack(self._batch_queue),
1610
+ _batched_cleanup=True, # avoid going through the decorator
1611
+ ).unbind(0)
1612
+ except Exception as exc:
1613
+ for fut in self._futures:
1614
+ if not fut.done():
1615
+ fut.set_exception(exc)
1616
+ else:
1617
+ # size mismatch ⇒ fall back to exceptions
1618
+ if len(results) != len(self._futures):
1619
+ exc = RuntimeError(
1620
+ f"cleanup_batching(): expected {len(self._futures)} "
1621
+ f"results, got {len(results)}"
1622
+ )
1623
+ for fut in self._futures:
1624
+ if not fut.done():
1625
+ fut.set_exception(exc)
1626
+ else:
1627
+ for fut, res in zip(self._futures, results):
1628
+ if not fut.done():
1629
+ fut.set_result(res)
1630
+ else:
1631
+ # cancel / fail everything so waiting threads can return
1632
+ cancel_exc = CancelledError("Batching aborted by cleanup_batching()")
1633
+ for fut in getattr(self, "_futures", ()):
1634
+ if not fut.done():
1635
+ fut.set_exception(cancel_exc)
1636
+
1637
+ # ── 3. Clear containers (they may hold large tensors) ────────────────
1638
+ self._batch_queue.clear()
1639
+ self._futures.clear()
1640
+
1641
+ def __del__(self):
1642
+ self.cleanup_batching()
1643
+
1644
+ def get_batching_state(self):
1645
+ """Get the current batching state for debugging and monitoring.
1646
+
1647
+ Returns:
1648
+ dict: A dictionary containing the current batching state including
1649
+ queue size, number of pending futures, and batch size.
1650
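+
+ Example (illustrative; actual values depend on the current state):
+ >>> policy.get_batching_state()  # doctest: +SKIP
+ {'batching_enabled': True, 'min_batch_size': 4, 'max_batch_size': 32, 'queue_size': 0,
+ 'processing': False, 'lock_state': 'unlocked', 'pending_futures': 0, 'timeout': 10.0}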
+ """
1651
+ if not self.batching:
1652
+ return {"batching_enabled": False}
1653
+
1654
+ lock = getattr(self, "_batching_lock", None)
1655
+ if lock is not None:
1656
+ lock_state = "locked" if lock.locked() else "unlocked"
1657
+ else:
1658
+ lock_state = "not initialized"
1659
+ return {
1660
+ "batching_enabled": True,
1661
+ "min_batch_size": getattr(self, "_min_batch_size", None),
1662
+ "max_batch_size": getattr(self, "_max_batch_size", None),
1663
+ "queue_size": len(getattr(self, "_batch_queue", [])),
1664
+ "processing": lock_state == "locked",
1665
+ "lock_state": lock_state,
1666
+ "pending_futures": len(getattr(self, "_futures", [])),
1667
+ "timeout": getattr(self, "_batching_timeout", None),
1668
+ }
1669
+
1670
+
1671
+ def _extract_responses_from_full_histories(
1672
+ text_full: list[str],
1673
+ prompt_histories,
1674
+ chat_template_name: str | None = None,
1675
+ tokenizer=None,
1676
+ ) -> History:
1677
+ """Extract response histories from full text histories.
1678
+
1679
+ This function parses the full text back to history objects and extracts
1680
+ the response portions (everything after the prompt).
1681
+
1682
+ Args:
1683
+ text_full: List of full text strings to parse
1684
+ prompt_histories: The original prompt histories
1685
+ chat_template_name: Optional chat template name for parsing
1686
+ tokenizer: Optional tokenizer for template detection
1687
+
1688
+ Returns:
1689
+ Stacked History object with response portions
1690
+
1691
+ Raises:
1692
+ RuntimeError: If full history is shorter than prompt history
1693
+ RuntimeError: If parsing produces inconsistent batch shapes
1694
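+
+ Example (illustrative sketch; ``text_full`` and ``prompt_histories`` are assumed to come from a generation step):
+ >>> responses = _extract_responses_from_full_histories(
+ ...     text_full, prompt_histories, chat_template_name="chatml_format", tokenizer=tokenizer
+ ... )
+ >>> responses.shape[0] == prompt_histories.shape[0]  # one response slice per prompt
+ True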
+ """
1695
+ import torch
1696
+ from tensordict.utils import _zip_strict
1697
+ from torchrl.data.llm import History
1698
+
1699
+ # Extract response portions by processing each element individually
1700
+ # This avoids the stacking issue when different batch elements produce
1701
+ # different numbers of responses
1702
+ response_histories = []
1703
+ full_histories = History.from_text(
1704
+ text_full, chat_template_name=chat_template_name, tokenizer=tokenizer
1705
+ )
1706
+ for h_prompt, h_full in _zip_strict(
1707
+ prompt_histories.unbind(0), full_histories.unbind(0)
1708
+ ):
1709
+ if h_full.shape[0] <= h_prompt.shape[0]:
1710
+ raise RuntimeError(
1711
+ f"Full history is shorter than prompt history: {h_full.shape} <= {h_prompt.shape}"
1712
+ )
1713
+ # Note: there can be more than one response, so the response has the same number of dims as prompt
1714
+ response_histories.append(h_full[h_prompt.shape[0] :])
1715
+
1716
+ # Check if all responses have the same shape
1717
+ shapes = [r.shape for r in response_histories]
1718
+ if len(set(shapes)) > 1:
1719
+ # Different shapes detected - pad to the same length
1720
+ max_length = max(r.shape[0] for r in response_histories)
1721
+ padded_responses = []
1722
+ for response in response_histories:
1723
+ if response.shape[0] < max_length:
1724
+ # Pad with empty messages using "<none>" role
1725
+ padding_needed = max_length - response.shape[0]
1726
+ padding_history = History(
1727
+ role="<none>", content="", batch_size=(padding_needed,)
1728
+ )
1729
+ padded_response = response.extend(padding_history, inplace=False)
1730
+ padded_responses.append(padded_response)
1731
+ else:
1732
+ padded_responses.append(response)
1733
+ return torch.stack(padded_responses)
1734
+
1735
+ return torch.stack(response_histories)
1736
+
1737
+
1738
+ def _batching(func):
1739
+ @wraps(func)
1740
+ def _batched_func(self, td_input: TensorDictBase, **kwargs):
1741
+ # -- 0. Bypass if batching disabled
1742
+ if not self.batching:
1743
+ return func(self, td_input, **kwargs)
1744
+
1745
+ # -- 1. Normalise --------------------------------------------------------
1746
+ if td_input.batch_dims > 1:
1747
+ raise ValueError(
1748
+ f"Batching not supported for batch_dims > 1: {td_input.batch_dims}"
1749
+ )
1750
+
1751
+ single = td_input.batch_dims == 0
1752
+ inputs = [td_input] if single else list(td_input.unbind(0))
1753
+ futures = [Future() for _ in inputs]
1754
+ pending = set(futures) # ← track our own Futures
1755
+
1756
+ # -- 2. Enqueue ----------------------------------------------------------
1757
+ self._batch_queue.extend(inputs)
1758
+ self._futures.extend(futures)
1759
+
1760
+ min_bs = getattr(self, "_min_batch_size", 1)
1761
+ max_bs = getattr(self, "_max_batch_size", None)
1762
+
1763
+ # -- 3. Drain while holding the lock ------------------------------------
1764
+ with self._batching_lock:
1765
+ if all(f.done() for f in futures):
1766
+ # Our items were already processed by another thread.
1767
+ # Skip draining; other workers will handle the rest of the queue.
1768
+ pass
1769
+ else:
1770
+ while len(self._batch_queue) >= min_bs:
1771
+ slice_size = (
1772
+ len(self._batch_queue)
1773
+ if max_bs is None
1774
+ else min(max_bs, len(self._batch_queue))
1775
+ )
1776
+ batch = self._batch_queue[:slice_size]
1777
+ fut_slice = self._futures[:slice_size]
1778
+
1779
+ try:
1780
+ results = func(self, lazy_stack(batch), **kwargs).unbind(0)
1781
+ if len(results) != slice_size:
1782
+ raise RuntimeError(
1783
+ f"Expected {slice_size} results, got {len(results)}"
1784
+ )
1785
+ for fut, res in zip(fut_slice, results):
1786
+ fut.set_result(res)
1787
+ pending.discard(fut) # ← mark as done
1788
+ except Exception as exc:
1789
+ for fut in fut_slice:
1790
+ fut.set_exception(exc)
1791
+ pending.discard(fut)
1792
+ raise
1793
+
1794
+ # Pop processed work
1795
+ del self._batch_queue[:slice_size]
1796
+ del self._futures[:slice_size]
1797
+
1798
+ # ---- Early-exit: all *our* Futures are done -------------------
1799
+ if not pending:
1800
+ break
1801
+
1802
+ # -- 4. Outside the lock: wait only on remaining (rare) -----------------
1803
+ if pending: # usually empty; safety for min_bs > queue size
1804
+ wait(pending)
1805
+ results = [f.result() for f in futures]
1806
+
1807
+ return results[0] if single else lazy_stack(results)
1808
+
1809
+ return _batched_func