torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/llm/policies/transformers_wrapper.py
@@ -0,0 +1,2756 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import contextlib
+import threading
+from contextlib import nullcontext
+from copy import copy
+from typing import Any, Literal
+
+import torch
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    MetaData,
+    NonTensorStack,
+    set_list_to_stack,
+    TensorDict,
+    TensorDictBase,
+)
+from tensordict.utils import _zip_strict, NestedKey
+from torch import distributions as D
+from torch.nn.utils.rnn import pad_sequence
+from torchrl import logger as torchrl_logger
+from torchrl.modules.llm.policies.common import (
+    _batching,
+    _extract_responses_from_full_histories,
+    ChatHistory,
+    LLMWrapperBase,
+    LogProbs,
+    Masks,
+    Text,
+    Tokens,
+)
+from torchrl.modules.utils.utils import _unpad_tensors
+
+
+class TransformersWrapper(LLMWrapperBase):
+    """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.
+
+    Packing vs Padding:
+        - Packing (`pad_model_input=False`):
+            * More memory efficient for variable-length sequences.
+            * Not all models support packed input (requires custom attention masks and position ids).
+            * May be less compatible with some HuggingFace models or custom architectures.
+        - Padding (`pad_model_input=True`):
+            * Universally supported by all models.
+            * Wastes memory for short sequences in a batch.
+            * Simpler, but less efficient for highly variable-length data.
+        - If unsure, use padding for maximum compatibility. Use packing for large batches of variable-length data and when your model supports it.
+
+    Additional error handling is provided for empty and overlong sequences.
+
+    Args:
+        model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap.
+            If a string, it will be passed to `transformers.AutoModelForCausalLM.from_pretrained` (and `AutoTokenizer.from_pretrained`
+            if `tokenizer` is not provided).
+
+    Keyword Args:
+        tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for
+            encoding and decoding text. If `None`, the tokenizer associated with the model will be used.
+            If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`. Defaults to `None`.
+        input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`.
+            Defaults to `"history"`.
+        input_key (str | None, optional): The key for the input data. If `None`, defaults to
+            - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
+            - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
+            - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
+        attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.
+
+            .. warning:: This argument is under development and may change in the future.
+
+        generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
+            If `False`, only log probabilities will be computed. Defaults to `True`.
+        return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`.
+        generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.
+
+            **Standardized Parameters (cross-backend compatible):**
+
+            * **max_new_tokens** (int): Maximum number of new tokens to generate
+            * **num_return_sequences** (int): Number of sequences to return
+            * **temperature** (float): Sampling temperature (0.0 = deterministic, higher = more random)
+            * **top_p** (float): Nucleus sampling parameter (0.0-1.0)
+            * **top_k** (int): Top-k sampling parameter
+            * **repetition_penalty** (float): Penalty for repeating tokens
+            * **do_sample** (bool): Whether to use sampling vs greedy decoding
+            * **num_beams** (int): Number of beams for beam search
+            * **length_penalty** (float): Penalty for sequence length
+            * **early_stopping** (bool): Whether to stop early in beam search
+            * **stop_sequences** (list): Sequences that stop generation (requires custom stopping criteria)
+            * **skip_special_tokens** (bool): Whether to skip special tokens in output
+            * **logprobs** (bool): Whether to return log probabilities (maps to output_scores)
+
+            .. warning:: Usage of this parameter is discouraged as it may conflict with the `generate` parameter
+                of the class.
+
+            **Transformers-Specific Parameters:**
+
+            * **pad_token_id** (int): Token ID for padding
+            * **eos_token_id** (int): Token ID for end of sequence
+            * **bad_words_ids** (list): List of token IDs to avoid
+            * **force_words_ids** (list): List of token IDs to force
+            * **no_repeat_ngram_size** (int): Size of n-grams to avoid repeating
+            * **encoder_repetition_penalty** (float): Repetition penalty for encoder-decoder models
+            * **num_beam_groups** (int): Number of beam groups for diverse beam search
+            * **diversity_penalty** (float): Penalty for beam diversity
+            * **output_scores** (bool): Whether to output scores
+            * **return_dict_in_generate** (bool): Whether to return dict in generate
+
+            **Legacy Parameter Support:**
+
+            * **max_tokens** (int): Automatically converted to max_new_tokens
+            * **n** (int): Automatically converted to num_return_sequences
+
+            **Parameter Conflict Resolution:**
+
+            When both legacy (Transformers-specific) and standardized parameter names are provided,
+            a :exc:`ValueError` is raised to prevent confusion. For example:
+
+            * If both ``max_tokens`` and ``max_new_tokens`` are passed, an error is raised
+            * If both ``n`` and ``num_return_sequences`` are passed, an error is raised
+
+            This ensures clear parameter usage and prevents unexpected behavior.
+
+        tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
+        pad_output (bool, optional): Whether to pad the output sequences to a uniform length. This does not impact the underlying padding
+            during the call to the model. To use padding or packing during the model `forward` call, see `pad_model_input`.
+            Defaults to `False`.
+        pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length.
+            If `False`, packing will be used instead. Packing is generally more memory efficient than padding,
+            but this feature may not work with all models.
+            `pad_model_input` can only be used when `generate=False`.
+            This does not impact the padding of the model output - one may ask for padded output through `pad_output=True` while the model
+            is called with `pad_model_input=False`.
+            Defaults to `True`.
+        inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
+        device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
+        layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
+        num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no batch-dimension for it).
+            Can also be set via the `generate_kwargs["num_return_sequences"] = value` argument. Requires the "do_sample" argument to be set to `True` in `generate_kwargs`.
+        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat
+            template to the history. Defaults to `None`. For `input_mode="history"` only.
+        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
+            Defaults to `None`. For `input_mode="history"` only.
+        log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+        text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+        tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+        masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+        history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
+        batching (bool | None, optional): Whether to enable batching. See `Batching`_ below for more details.
+        min_batch_size (int | None, optional): The minimum batch size to use for batching. See `Batching`_ below for more details.
+        max_batch_size (int | None, optional): The maximum batch size to use for batching. See `Batching`_ below for more details.
+        batching_timeout (float, optional): The timeout for batching. See `Batching`_ below for more details.
+
+    .. _Batching:
+
+    **Batching**
+
+    Batching is a feature that allows the module to process multiple inputs in a single call.
+    It is designed to work in a multi-threaded environment.
+    To enable batching, it suffices to set `batching=True`, which will set `min_batch_size` to 1 if not provided.
+    If you want to set a different value for `min_batch_size` or `max_batch_size` for fine-grained control,
+    you can set `batching=True` and then set `min_batch_size` or `max_batch_size` to a value greater than or equal to 1.
+    The way batching works is as follows:
+    - If `min_batch_size` is not provided but `max_batch_size` is, `min_batch_size` is set to 1.
+    - If `max_batch_size` is not provided but `min_batch_size` is, `max_batch_size` is set to the number of inputs in the queue.
+    - When the model is called, a check is performed to see if the number of inputs in the queue is greater than or equal to `min_batch_size`.
+      If it is, the batch is processed immediately, while waiting for the previous batch to be processed if the model is busy.
+      Otherwise, the input is added to the queue and the function waits for the batch to be completed.
+      While waiting for the batch to be completed, a timeout is set to `batching_timeout` seconds such that if the batch is not
+      completed after `batching_timeout` seconds, the remaining items to process are processed as is and the function returns after
+      at most `batching_timeout` seconds (plus the time to finish processing the previous and current batch).
+
+    Input Keys:
+        The input key depends on both `input_mode` and `generate`:
+
+        - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
+        - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
+        - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
+        - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
+        - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
+        - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)
+
+    Output Keys:
+        The output keys are automatically determined based on the input_mode:
+        - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
+        - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
+        - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
+        - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
+        - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)
+
+    Example output structure for `input_mode="history"`::
+
+        TensorDict(
+            text=Text(prompt=..., response=..., full=...),
+            masks=Masks(all_attention_mask=..., all_assistant_mask=...),
+            tokens=Tokens(prompt=..., response=..., full=...),
+            log_probs=LogProbs(prompt=..., response=..., full=...),
+            history=ChatHistory(prompt=..., response=..., full=...)
+        )
+
+    Example:
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+        >>> from torchrl.data.llm import History
+        >>> from torchrl.modules.llm.policies import ChatHistory
+        >>>
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>>
+        >>> # History input (recommended for RL environments)
+        >>> wrapper = TransformersWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     input_mode="history",
+        ...     generate=True,
+        ...     return_log_probs=True,
+        ...     generate_kwargs={
+        ...         "max_new_tokens": 50,  # Standardized parameter
+        ...         "temperature": 0.7,
+        ...         "top_p": 0.9,
+        ...         "do_sample": True,
+        ...     }
+        ... )
+        >>>
+        >>> history = History.from_chats([[
+        ...     {"role": "user", "content": "Hello"},
+        ...     {"role": "assistant", "content": "Hi there!"}
+        ... ]])
+        >>> chat_history = ChatHistory(prompt=history)
+        >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+        >>> print(result["text"].response)  # Generated text
+        >>> print(result["log_probs"].response)  # Log probabilities
+        >>> print(result["history"].response)  # History with response
+
+    Attributes:
+        collector: The collector associated with the module, if it exists.
+
+    .. seealso::
+        - :class:`~torchrl.modules.llm.policies.LLMWrapperBase`
+        - :class:`~torchrl.modules.llm.policies.vLLMWrapper`
+    """
+
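The docstring example above only covers the generation path. A minimal sketch of the complementary log-probability path it describes (generate=False, which the constructor resolves to return_log_probs=True) could look like the following; the model name, the import paths and the ChatHistory(full=...) construction are assumptions for illustration, not guaranteed by this diff:

    # Sketch: score an existing assistant turn instead of generating a new one.
    from tensordict import TensorDict
    from torchrl.data.llm import History
    from torchrl.modules.llm.policies import ChatHistory, TransformersWrapper

    ref_wrapper = TransformersWrapper(
        "gpt2",                # a string is resolved via from_pretrained (see __init__ below)
        input_mode="history",
        generate=False,        # log-probs only; max_new_tokens is internally forced to 1
        pad_output=True,
    )
    chats = History.from_chats([[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]])
    # With generate=False the default input key is ("history", "full").
    td = ref_wrapper(TensorDict(history=ChatHistory(full=chats), batch_size=(1,)))
    log_probs = td["log_probs"]  # LogProbs object with prompt/response/full fields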
+    def __init__(
+        self,
+        model,
+        *,
+        tokenizer=None,
+        input_mode: str = "history",
+        input_key: str | None = None,
+        attention_mask_key: str = "attention_mask",
+        generate: bool = True,
+        generate_kwargs: dict | None = None,
+        tokenizer_kwargs: dict | None = None,
+        pad_output: bool = False,
+        pad_model_input: bool | None = None,
+        inplace: Literal[True, False, "empty"] | None = None,
+        device: torch.device | None = None,
+        layout: torch.layout | None = None,
+        num_samples: int | None = None,
+        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+        chat_template: str | None = None,
+        return_log_probs: bool | None = None,
+        history_key: NestedKey | None = "history",
+        text_key: NestedKey | None = "text",
+        tokens_key: NestedKey | None = "tokens",
+        masks_key: NestedKey | None = "masks",
+        log_probs_key: NestedKey | None = "log_probs",
+        batching: bool | None = None,
+        min_batch_size: int | None = None,
+        max_batch_size: int | None = None,
+        batching_timeout: float = 10.0,
+    ):
+        super().__init__()
+
+        if batching and min_batch_size is None:
+            min_batch_size = 1
+        elif (min_batch_size is not None or max_batch_size is not None) and (
+            batching is False
+        ):
+            raise ValueError(
+                "min_batch_size and max_batch_size must be None if batching is False."
+            )
+
+        # Validate that min_batch_size <= max_batch_size when both are specified
+        if min_batch_size is not None and max_batch_size is not None:
+            if min_batch_size > max_batch_size:
+                raise ValueError(
+                    f"min_batch_size ({min_batch_size}) must be <= max_batch_size ({max_batch_size})"
+                )
+
+        self._min_batch_size = min_batch_size
+        self._max_batch_size = max_batch_size
+        self._batching_timeout = batching_timeout
+        self._batch_queue = []
+        self._futures = []
+        if self.batching:
+            self._batching_lock = threading.Lock()
+        else:
+            self._batching_lock = None
+
+        if isinstance(model, str):
+            if tokenizer is None:
+                from transformers import AutoTokenizer
+
+                tokenizer = AutoTokenizer.from_pretrained(model)
+
+            from transformers import AutoModelForCausalLM
+
+            model = AutoModelForCausalLM.from_pretrained(model)
+
+        if isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        # Validate input_mode
+        if input_mode not in ["history", "text", "tokens"]:
+            raise ValueError(
+                f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
+            )
+
+        self.model = model
+        self.input_mode = input_mode
+        self.attention_mask_key = attention_mask_key
+        self.generate = generate
+        if pad_model_input is not None and generate:
+            raise ValueError("pad_model_input is not supported when generate=True.")
+        pad_model_input = pad_model_input if pad_model_input is not None else True
+        self.pad_model_input = pad_model_input
+
+        # Auto-determine what to return based on input mode
+        self.return_history = input_mode in ("history",)
+        self.return_text = input_mode in ("text", "history")
+        self.return_tokens = input_mode in ("tokens", "history", "text")
+        self.return_masks = True
+        if return_log_probs is False and not generate:
+            raise ValueError("return_log_probs must be True when generate=False.")
+        return_log_probs = (
+            True
+            if (return_log_probs is None and generate) or (not generate)
+            else bool(return_log_probs)
+        )
+        self.return_log_probs = return_log_probs
+
+        self.history_key = history_key
+        self.text_key = text_key
+        self.tokens_key = tokens_key
+        self.masks_key = masks_key
+        self.log_probs_key = log_probs_key
+        if not isinstance(pad_output, bool):
+            raise ValueError("pad_output must be a boolean")
+        self.pad_output = pad_output
+        self._device = device
+        if not pad_output and layout is None:
+            layout = torch.strided
+        self.layout = layout
+        padding_value = None
+
+        # Auto-determine input_key if not provided
+
+        # Set input keys based on mode and generate parameter
+        if input_mode == "history":
+            if generate:
+                self.in_keys = [
+                    ("history", "prompt") if input_key is None else input_key
+                ]
+            else:
+                self.in_keys = [("history", "full") if input_key is None else input_key]
+        elif input_mode == "text":
+            if generate:
+                self.in_keys = [("text", "prompt") if input_key is None else input_key]
+            else:
+                self.in_keys = [("text", "full") if input_key is None else input_key]
+        elif input_mode == "tokens":
+            if generate:
+                self.in_keys = [
+                    ("tokens", "prompt") if input_key is None else input_key
+                ]
+            else:
+                self.in_keys = [("tokens", "full") if input_key is None else input_key]
+        self.input_key = self.in_keys[0]
+
+        # Set output keys based on auto-determined return flags
+        self.out_keys = []
+        if self.return_text:
+            self.out_keys.append(self.text_key)
+        if self.return_masks:
+            self.out_keys.append(self.masks_key)
+        if self.return_tokens:
+            self.out_keys.append(self.tokens_key)
+        if self.return_log_probs:
+            self.out_keys.append(self.log_probs_key)
+        if self.return_history:
+            self.out_keys.append(self.history_key)
+
+        # Tokenizer setup
+        if not tokenizer_kwargs:
+            tokenizer_kwargs = {}
+        else:
+            tokenizer_kwargs = dict(tokenizer_kwargs)
+        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
+            raise RuntimeError("return_attention_mask must be True")
+
+        # We always pad, so we always return tensors
+        return_tensors = "pt"
+        tokenizer_kwargs.setdefault("padding", True)
+        if return_tensors:
+            if (
+                tokenizer_kwargs.setdefault("return_tensors", return_tensors)
+                != return_tensors
+            ):
+                raise RuntimeError
+
+        # We always pad atm
+        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
+            raise RuntimeError
+
+        self.tokenizer_kwargs = tokenizer_kwargs
+
+        # Get tokenizer if needed
+        if (
+            pad_output or (input_mode in ["text", "history"] and not generate)
+        ) and tokenizer is None:
+            tokenizer = model.get_tokenizer()
+        self.tokenizer = tokenizer
+
+        if self.tokenizer is not None and (
+            not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
+        ):
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        if self.tokenizer is not None:
+            padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
+        self.padding_value = padding_value
+
+        # Generate kwargs setup
+        if generate_kwargs is None:
+            generate_kwargs = {}
+        else:
+            generate_kwargs = dict(generate_kwargs)
+
+        # Standardize common parameters
+        generate_kwargs = self._standardize_generate_kwargs(generate_kwargs)
+
+        # Extract wrapper-specific parameters
+        transformers_specific_kwargs = self._get_wrapper_specific_kwargs(
+            generate_kwargs, "transformers"
+        )
+
+        # Convert common parameters to Transformers format
+        transformers_kwargs = {}
+        for key, value in generate_kwargs.items():
+            if key in self.COMMON_GENERATION_PARAMS:
+                # Convert common names to Transformers names
+                if key == "stop_sequences":
+                    # Transformers uses stopping_criteria for stop sequences
+                    # This requires custom stopping criteria implementation
+                    # For now, we'll warn and skip this parameter
+                    import warnings
+
+                    warnings.warn(
+                        "stop_sequences parameter is not yet fully supported in TransformersWrapper. "
+                        "Use eos_token_id or implement custom stopping criteria for full support.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    continue
+                elif key == "logprobs":
+                    transformers_kwargs["output_scores"] = value
+                else:
+                    # Direct mapping for other common parameters
+                    transformers_kwargs[key] = value
+
+        # Add Transformers-specific parameters
+        transformers_kwargs.update(transformers_specific_kwargs)
+
+        self.num_samples = num_samples
+        if (
+            transformers_kwargs.get("num_return_sequences", 1) > 1
+            or num_samples is not None
+        ):
+            if inplace in (True, "empty"):
+                raise ValueError(
+                    "inplace must be False (or None) when generating more than one sample."
+                )
+            if inplace is None:
+                inplace = False
+            if (
+                transformers_kwargs.get("num_return_sequences", 1) > 1
+                and num_samples is not None
+                and transformers_kwargs.get("num_return_sequences", 1) != num_samples
+            ):
+                raise ValueError("num_samples differs from generate_kwargs['n'].")
+            elif num_samples is None:
+                self.num_samples = transformers_kwargs.get("num_return_sequences", 1)
+            transformers_kwargs["num_return_sequences"] = self.num_samples
+        elif inplace is None:
+            inplace = True
+
+        self.inplace = inplace
+
+        if not generate:
+            # We want only the log-probs, we generate a single token (that we then discard)
+            # and retrieve the prompt log-probs
+            transformers_kwargs["max_new_tokens"] = 1
+
+        transformers_kwargs.setdefault("tokenizer", self.tokenizer)
+        transformers_kwargs.setdefault("output_logits", self.return_log_probs)
+        transformers_kwargs.setdefault("return_dict_in_generate", True)
+
+        self.generate_kwargs = transformers_kwargs
+
+        # Additional transformers-specific settings
+        self.chat_template_name = chat_template_name
+        self.chat_template = chat_template
+
+        # Flag to track when we're in a get_dist call
+        self._in_get_dist_call = False
+
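A brief illustration of the argument validation performed by the constructor above; the `model` and `tokenizer` names are placeholders for real objects, and both errors are raised before the model is touched:

    # Batch-size bounds are rejected when batching is explicitly disabled.
    TransformersWrapper(model, tokenizer=tokenizer, batching=False, min_batch_size=4)
    # ValueError: min_batch_size and max_batch_size must be None if batching is False.

    # Inconsistent bounds are rejected even with batching enabled.
    TransformersWrapper(model, tokenizer=tokenizer, batching=True, min_batch_size=8, max_batch_size=4)
    # ValueError: min_batch_size (8) must be <= max_batch_size (4)

Note also that, per the conversion loop above, a `stop_sequences` entry in `generate_kwargs` is currently warned about and dropped rather than converted into a stopping criterion.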
+    def get_new_version(self, **kwargs):
+        """Returns a new version of the module with altered parameters.
+
+        For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
+        This is especially useful when one wants to avoid re-initializing the module with a new set of parameters, when the
+        same parameters could be used to gather log-probs.
+
+        Positional arguments are not supported.
+
+        See the class constructor for more details about the parameters.
+        """
+        # Build the constructor arguments by using current values for missing parameters
+        constructor_kwargs = {}
+
+        # Model is always required
+        constructor_kwargs["model"] = kwargs.get("model", self.model)
+
+        # Check for each parameter and use current value if not provided
+        if "tokenizer" in kwargs:
+            constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
+        elif hasattr(self, "tokenizer"):
+            constructor_kwargs["tokenizer"] = self.tokenizer
+
+        if "input_mode" in kwargs:
+            constructor_kwargs["input_mode"] = kwargs["input_mode"]
+        elif hasattr(self, "input_mode"):
+            constructor_kwargs["input_mode"] = self.input_mode
+
+        if "input_key" in kwargs:
+            constructor_kwargs["input_key"] = kwargs["input_key"]
+        elif hasattr(self, "input_key"):
+            constructor_kwargs["input_key"] = self.input_key
+
+        if "attention_mask_key" in kwargs:
+            constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
+        elif hasattr(self, "attention_mask_key"):
+            constructor_kwargs["attention_mask_key"] = self.attention_mask_key
+
+        if "generate" in kwargs:
+            constructor_kwargs["generate"] = kwargs["generate"]
+        elif hasattr(self, "generate"):
+            constructor_kwargs["generate"] = self.generate
+
+        if "generate_kwargs" in kwargs:
+            constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
+        elif hasattr(self, "generate_kwargs"):
+            constructor_kwargs["generate_kwargs"] = self.generate_kwargs
+
+        if "pad_output" in kwargs:
+            constructor_kwargs["pad_output"] = kwargs["pad_output"]
+        elif hasattr(self, "pad_output"):
+            constructor_kwargs["pad_output"] = self.pad_output
+
+        if "tokenizer_kwargs" in kwargs:
+            constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
+        elif hasattr(self, "tokenizer_kwargs"):
+            constructor_kwargs["tokenizer_kwargs"] = self.tokenizer_kwargs
+            if (
+                "pad_output" in kwargs
+                and kwargs.get("pad_output")
+                != constructor_kwargs["tokenizer_kwargs"]["padding"]
+            ):
+                constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
+                    "pad_output"
+                )
+
+        if "inplace" in kwargs:
+            constructor_kwargs["inplace"] = kwargs["inplace"]
+        elif hasattr(self, "inplace"):
+            constructor_kwargs["inplace"] = self.inplace
+
+        if "device" in kwargs:
+            constructor_kwargs["device"] = kwargs["device"]
+        elif hasattr(self, "_device"):
+            constructor_kwargs["device"] = self._device
+
+        if "layout" in kwargs:
+            constructor_kwargs["layout"] = kwargs["layout"]
+        elif hasattr(self, "layout"):
+            constructor_kwargs["layout"] = self.layout
+
+        if "num_samples" in kwargs:
+            constructor_kwargs["num_samples"] = kwargs["num_samples"]
+        elif hasattr(self, "num_samples"):
+            constructor_kwargs["num_samples"] = self.num_samples
+
+        if "chat_template_name" in kwargs:
+            constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
+        elif hasattr(self, "chat_template_name"):
+            constructor_kwargs["chat_template_name"] = self.chat_template_name
+
+        if "chat_template" in kwargs:
+            constructor_kwargs["chat_template"] = kwargs["chat_template"]
+        elif hasattr(self, "chat_template"):
+            constructor_kwargs["chat_template"] = self.chat_template
+
+        if "text_key" in kwargs:
+            constructor_kwargs["text_key"] = kwargs["text_key"]
+        elif hasattr(self, "text_key"):
+            constructor_kwargs["text_key"] = self.text_key
+
+        if "tokens_key" in kwargs:
+            constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
+        elif hasattr(self, "tokens_key"):
+            constructor_kwargs["tokens_key"] = self.tokens_key
+
+        if "masks_key" in kwargs:
+            constructor_kwargs["masks_key"] = kwargs["masks_key"]
+        elif hasattr(self, "masks_key"):
+            constructor_kwargs["masks_key"] = self.masks_key
+
+        if "log_probs_key" in kwargs:
+            constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
+        elif hasattr(self, "log_probs_key"):
+            constructor_kwargs["log_probs_key"] = self.log_probs_key
+
+        # Create and return new instance
+        return type(self)(**constructor_kwargs)
+
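In practice, get_new_version is the documented way to derive a scoring twin of a generation wrapper without re-loading the model. A minimal, assumed usage, with `wrapper` being the instance from the class docstring example:

    # Re-use the wrapper's model, tokenizer and key configuration, but flip it to scoring mode.
    ref_model = wrapper.get_new_version(generate=False)
    # With generate=False, return_log_probs resolves to True, so calling ref_model
    # returns per-token log-probabilities instead of newly generated text.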
+    @set_list_to_stack(True)
+    @_batching
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        *,
+        tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
+        **kwargs,
+    ) -> TensorDictBase:
+        tensordict_orig = tensordict
+        if not tensordict.ndim:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                    "please submit an issue on github."
+                )
+            # unsqueeze - squeeze the input
+            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
+        elif tensordict.ndim > 1:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                    "please submit an issue on github."
+                )
+            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                tensordict.shape
+            )
+
+        if not isinstance(tensordict, LazyStackedTensorDict):
+            tensordict = tensordict.to_lazystack(0)
+
+        _source_device = None
+        if self._device:
+            _source_device = tensordict.device
+        if tensordict.device:
+            tensordict = tensordict.copy().clear_device_()
+
+        if kwargs:
+            from transformers import GenerationConfig
+
+            cfg = GenerationConfig(**kwargs)
+        else:
+            cfg = None
+
+        if self.num_samples is not None:
+            out = (
+                TensorDict(
+                    device=tensordict.device,
+                    batch_size=(
+                        tensordict.batch_size[0],
+                        self.num_samples,
+                        *tensordict.batch_size[1:],
+                    ),
+                )
+                .to_lazystack(1)
+                .to_lazystack(0)
+            )
+        else:
+            out = TensorDict(
+                device=tensordict.device, batch_size=tensordict.batch_size
+            ).to_lazystack(0)
+
+        if self.input_mode == "history":
+            if self.generate:
+                out = self._from_transformers_generate_history(tensordict, cfg, out)
+            else:
+                out = self._from_transformers_logprobs_history(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
+        elif self.input_mode == "text":
+            if self.generate:
+                out = self._from_transformers_generate_text(tensordict, cfg, out)
+            else:
+                out = self._from_transformers_logprobs_text(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
+        elif self.input_mode == "tokens":
+            if self.generate:
+                out = self._from_transformers_generate_tokens(tensordict, cfg, out)
+            else:
+                out = self._from_transformers_logprobs_tokens(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
+
+        if _source_device:
+            out = out.to(_source_device)
+
+        if tensordict_out is None:
+            if self.inplace is True:
+                # The output is the input
+                tensordict_out = tensordict_orig
+            elif self.inplace is False:
+                # The output is the new structure
+                tensordict_out = out
+            elif self.inplace == "empty":
+                # The output is empty
+                tensordict_out = tensordict.empty()
+
+        if tensordict_out is not None and tensordict_out is not out:
+            result = tensordict_out.exclude(*self.out_keys, inplace=True)
+            result.update(out, keys_to_update=self.out_keys)
+        elif tensordict_out is out:
+            result = out.select(*self.out_keys)
+        elif self.inplace:
+            result = out
+            keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
+            result = tensordict.exclude(*self.out_keys, inplace=True).update(
+                result, keys_to_update=keys
+            )
+        else:
+            result = out
+        return result
+
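The forward pass above first normalizes the batch shape: a 0-dimensional tensordict is wrapped in a lazy stack and unwrapped again, and inputs with more than one batch dimension are flattened and viewed back, so (when num_samples is not set) the output keeps the caller's batch size. A rough sketch of that contract, with `wrapper` and `batched_chat_history` as placeholders:

    td = TensorDict(history=batched_chat_history, batch_size=(2, 3))
    out = wrapper(td)
    assert out.batch_size == (2, 3)  # processed internally as a flat batch of 6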
754
+ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
755
+ """Generate text from history input."""
756
+ from torchrl.data.llm import History
757
+
758
+ # Validate input
759
+ if self.input_key not in td:
760
+ raise ValueError(
761
+ f"Expected '{self.input_key}' key for history input mode, "
762
+ f"but found keys: {list(td.keys())}"
763
+ )
764
+
765
+ history = td.get(self.input_key)
766
+ if not isinstance(history, History):
767
+ raise TypeError(
768
+ f"Expected History object for '{self.input_key}', got {type(history)}"
769
+ )
770
+
771
+ # Apply chat template
772
+ tokenizer_kwargs = {}
773
+ if self.chat_template_name is not None:
774
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
775
+ if self.chat_template is not None:
776
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
777
+ tokenizer_kwargs.setdefault("add_generation_prompt", True)
778
+ text_prompt = history.apply_chat_template(
779
+ tokenizer=self.tokenizer, **tokenizer_kwargs
780
+ )
781
+ if not isinstance(text_prompt, list):
782
+ raise ValueError(
783
+ f"Expected list of text for history input, got {type(text_prompt)}"
784
+ )
785
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
786
+ tokenizer_kwargs.setdefault("tokenize", True)
787
+ tokenizer_kwargs.setdefault("padding", False)
788
+ tokenizer_kwargs.setdefault("return_dict", True)
789
+ response_struct = history.apply_chat_template(
790
+ tokenizer=self.tokenizer, **tokenizer_kwargs
791
+ )
792
+
793
+ if self._device is not None:
794
+ response_struct = response_struct.to(self._device)
795
+
796
+ tokens_prompt_padded = response_struct.get(
797
+ "input_ids",
798
+ as_padded_tensor=True,
799
+ padding_value=self.padding_value,
800
+ padding_side="left",
801
+ )
802
+ attention_mask_prompt_padded = response_struct.get(
803
+ "attention_mask",
804
+ as_padded_tensor=True,
805
+ padding_value=0,
806
+ padding_side="left",
807
+ )
808
+
809
+ if attention_mask_prompt_padded is None:
810
+ attention_mask_prompt_padded = (
811
+ tokens_prompt_padded != self.tokenizer.pad_token_id
812
+ )
813
+
814
+ result = self._generate_from_tokens(
815
+ tokens_prompt_padded, attention_mask_prompt_padded, cfg, out
816
+ )
817
+
818
+ # Generate using text path
819
+ if self.pad_output:
820
+ result[(self.tokens_key, "prompt")] = (
821
+ tokens_prompt_padded
822
+ if not self.num_samples
823
+ else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
824
+ )
825
+ else:
826
+ tokens_prompt_unpadded = response_struct.get(
827
+ "input_ids",
828
+ as_nested_tensor=True,
829
+ )
830
+ if not self.num_samples:
831
+ result[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
832
+ else:
833
+ for r in result.unbind(1):
834
+ r[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
835
+
836
+ text_result = Text._from_tensordict(result.empty())
837
+ result.set(self.text_key, text_result)
838
+ if not self.num_samples:
839
+ text_result.prompt = text_prompt
840
+ else:
841
+ for r in result.unbind(1):
842
+ r[self.text_key, "prompt"] = text_prompt
843
+ with result.view(-1) as result_flat:
844
+ if self.pad_output:
845
+ tokens_full_padded = result_flat.get(
846
+ (self.tokens_key, "full"),
847
+ as_padded_tensor=True,
848
+ padding_side="right",
849
+ padding_value=self.padding_value,
850
+ )
851
+ if tokens_full_padded is None:
852
+ raise ValueError("tokens_full_padded is None")
853
+ text_full = self.tokenizer.batch_decode(
854
+ tokens_full_padded, skip_special_tokens=False
855
+ )
856
+ else:
857
+ tokens_full_unpadded = result_flat.get(
858
+ (self.tokens_key, "full"), as_list=True
859
+ )
860
+ if tokens_full_unpadded is None:
861
+ raise ValueError("tokens_full_unpadded is None")
862
+ text_full = self.tokenizer.batch_decode(
863
+ tokens_full_unpadded, skip_special_tokens=False
864
+ )
865
+ text_prompt = result_flat[self.text_key, "prompt"]
866
+ text_response = [
867
+ txt[len(prompt) :]
868
+ for txt, prompt in _zip_strict(text_full, text_prompt)
869
+ ]
870
+ result_flat.set((self.text_key, "full"), text_full)
871
+ result_flat.set((self.text_key, "response"), text_response)
872
+ # Now parse the full text back to a history object, and use the extra history objects
873
+ # as response
874
+ history_chat = ChatHistory._from_tensordict(result.empty())
875
+ if self.num_samples is None:
876
+ history_chat.prompt = history
877
+ else:
878
+ for h in history_chat.unbind(1):
879
+ h.prompt = history
880
+ with history_chat.view(-1) as history_chat_flat:
881
+ prompt_histories = history_chat_flat.prompt
882
+ # Extract response histories from full text
883
+ h_responses = _extract_responses_from_full_histories(
884
+ text_full, prompt_histories, self.chat_template_name, self.tokenizer
885
+ )
886
+ history_chat_flat.response = h_responses
887
+ # Combine prompt and response to create full history
888
+ history_chat_flat.full = history_chat_flat.prompt.extend(
889
+ h_responses, inplace=False, dim=-1
890
+ )
891
+ result.set(self.history_key, history_chat)
892
+ return result
893
+
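A minimal sketch of the two-step chat-template flow used above, outside the wrapper: render the history to a prompt string with the generation prompt appended, then tokenize the same history into a dict. This is illustrative only and not part of the package; the model name (and it shipping a chat template) is an assumption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed model
messages = [{"role": "user", "content": "Hello!"}]
# string prompt with the generation prompt appended (generation path)
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# tokenized dict, analogous to response_struct above
encoded = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
print(prompt)
print(encoded["input_ids"][:10])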
894
+ def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
895
+ """Compute log-probs from history input."""
896
+ from torchrl.data.llm import History
897
+
898
+ # Validate input
899
+ if self.input_key not in td:
900
+ raise ValueError(
901
+ f"Expected '{self.input_key}' key for history input mode, "
902
+ f"but found keys: {list(td.keys())}"
903
+ )
904
+
905
+ history = td.get(self.input_key)
906
+ if not isinstance(history, History):
907
+ raise TypeError(
908
+ f"Expected History object for '{self.input_key}', got {type(history)}"
909
+ )
910
+
911
+ # Apply chat template
912
+ tokenizer_kwargs = {}
913
+ if self.chat_template_name is not None:
914
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
915
+ if self.chat_template is not None:
916
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
917
+ tokenizer_kwargs.setdefault("add_generation_prompt", False)
918
+ text_full = history.apply_chat_template(
919
+ tokenizer=self.tokenizer, **tokenizer_kwargs
920
+ )
921
+
922
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
923
+ tokenizer_kwargs.setdefault("tokenize", True)
924
+ tokenizer_kwargs.setdefault("padding", False)
925
+ tokenizer_kwargs.setdefault("return_dict", True)
926
+
927
+ with torch.device(self._device) if self._device is not None else nullcontext():
928
+ response_tokens = history.apply_chat_template(
929
+ tokenizer=self.tokenizer, **tokenizer_kwargs
930
+ )
931
+ if not isinstance(response_tokens, TensorDictBase):
932
+ raise ValueError(
933
+ f"Expected TensorDictBase for history input, got {type(response_tokens)}"
934
+ )
935
+ result = self._logprobs_from_history_tokens(
936
+ response_tokens, cfg, out, logits_only=logits_only
937
+ )
938
+ text_result = Text._from_tensordict(result.empty())
939
+ result.set(self.text_key, text_result)
940
+ result[self.text_key, "full"] = text_full
941
+ result.set(self.history_key, ChatHistory(full=history))
942
+ return result
943
+
944
+ def _cat_text(self, text, response_text):
945
+ """Concatenate text and response text."""
946
+ if isinstance(text, list):
947
+ return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
948
+ else:
949
+ return text + response_text
950
+
951
+ def _generate_from_text(self, text, cfg, out) -> TensorDictBase:
952
+ """Generate text from text input."""
953
+ pad_val = self.tokenizer.pad_token_id
954
+
955
+ # Convert text to list format
956
+ if isinstance(text, str):
957
+ text = [text]
958
+ elif not isinstance(text, list):
959
+ text = text.tolist()
960
+
961
+ tokenizer_kwargs = dict(self.tokenizer_kwargs)
962
+ tokenizer_kwargs.setdefault("padding", True)
963
+
964
+ with torch.device(
965
+ self._device
966
+ ) if self._device is not None else contextlib.nullcontext():
967
+ tokens_in = self.tokenizer(text, **tokenizer_kwargs)
968
+ if self._device is not None:
969
+ tokens_in = tokens_in.to(self._device)
970
+ # Map tokens_in to a TensorDict to make padding easier when we need it
971
+ tokens_in = dict(tokens_in)
972
+ for k, v in dict(tokens_in).items():
973
+ if isinstance(v, list):
974
+ if isinstance(v[0], torch.Tensor):
975
+ v = torch.nested.nested_tensor(v)
976
+ else:
977
+ v = torch.nested.nested_tensor([torch.tensor(t) for t in v])
978
+ tokens_in[k] = v
979
+ tokens_in = (
980
+ TensorDict(batch_size=tokens_in["input_ids"].size(0))
981
+ .to_lazystack(0)
982
+ .update(tokens_in)
983
+ )
984
+ tokens_prompt_padded = tokens_in.get(
985
+ "input_ids",
986
+ as_padded_tensor=True,
987
+ padding_side="left",
988
+ padding_value=pad_val,
989
+ )
990
+ attention_mask_prompt_padded = tokens_in.get(
991
+ "attention_mask",
992
+ as_padded_tensor=True,
993
+ padding_side="left",
994
+ padding_value=0,
995
+ )
996
+
997
+ if cfg is not None:
998
+ kwargs = copy(self.generate_kwargs)
999
+ kwargs["generation_config"] = cfg
1000
+ else:
1001
+ kwargs = self.generate_kwargs
1002
+
1003
+ tokens_out = self.model.generate(
1004
+ input_ids=tokens_prompt_padded,
1005
+ attention_mask=attention_mask_prompt_padded,
1006
+ **kwargs,
1007
+ )
1008
+ tokens_full_padded = tokens_out["sequences"]
1009
+ tokens_response_padded = tokens_full_padded[
1010
+ ..., tokens_prompt_padded.shape[-1] :
1011
+ ]
1012
+
1013
+ attention_mask_response_padded = tokens_response_padded != pad_val
1014
+ if self.num_samples:
1015
+ attention_mask_full_padded = torch.cat(
1016
+ [
1017
+ attention_mask_prompt_padded.repeat_interleave(
1018
+ self.num_samples, dim=0
1019
+ ),
1020
+ attention_mask_response_padded,
1021
+ ],
1022
+ dim=-1,
1023
+ )
1024
+ else:
1025
+ attention_mask_full_padded = torch.cat(
1026
+ [attention_mask_prompt_padded, attention_mask_response_padded], dim=-1
1027
+ )
1028
+ tokens_response_unpadded = _unpad_tensors(
1029
+ tokens_response_padded, attention_mask_response_padded, as_nested=False
1030
+ )
1031
+
1032
+ if self.return_log_probs:
1033
+ # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
1034
+ logits = torch.stack(list(tokens_out["logits"]), 1)
1035
+ log_probs, logits = self._log_probs_generate(
1036
+ tokens_response_padded, logits, pad_val=-100, pad=False
1037
+ )
1038
+
1039
+ response_text = self.tokenizer.batch_decode(
1040
+ tokens_response_unpadded, skip_special_tokens=False
1041
+ )
1042
+
1043
+ # Build output TensorClass objects
1044
+ if self.num_samples is not None:
1045
+ text = [txt for txt in text for _ in range(self.num_samples)]
1046
+ text_obj = Text._from_tensordict(out.empty())
1047
+ with text_obj.view(-1) as text_obj_flat:
1048
+ text_obj_flat.prompt = text
1049
+ text_obj_flat.response = response_text
1050
+ text_obj_flat.full = self._cat_text(text, response_text)
1051
+ out.set(self.text_key, text_obj)
1052
+
1053
+ tokens_obj = Tokens._from_tensordict(out.empty())
1054
+ if self.pad_output:
1055
+ prompt = tokens_prompt_padded
1056
+ else:
1057
+ prompt = _unpad_tensors(
1058
+ tokens_prompt_padded, attention_mask_prompt_padded, as_nested=False
1059
+ )
1060
+ if tokens_obj.ndim == 2:
1061
+ for i in range(self.num_samples):
1062
+ tokens_obj[:, i].prompt = prompt
1063
+ else:
1064
+ tokens_obj.prompt = prompt
1065
+ with tokens_obj.view(-1) as tokens_obj_flat:
1066
+ if not self.pad_output:
1067
+ tokens_obj_flat.response = tokens_response_unpadded
1068
+ tokens_full_unpadded = _unpad_tensors(
1069
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
1070
+ )
1071
+ tokens_obj_flat.full = tokens_full_unpadded
1072
+ else:
1073
+ tokens_obj_flat.response = tokens_response_padded
1074
+ tokens_obj_flat.full = tokens_full_padded
1075
+ tokens_obj.padded = MetaData(self.pad_output)
1076
+ out.set(self.tokens_key, tokens_obj)
1077
+
1078
+ masks_obj = Masks._from_tensordict(out.empty())
1079
+ if out.ndim == 2:
1080
+ attention_mask_full_padded = attention_mask_full_padded.unflatten(
1081
+ 0, (-1, self.num_samples)
1082
+ )
1083
+ if self.pad_output:
1084
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
1085
+ else:
1086
+ if out.ndim == 2:
1087
+ with tokens_obj.view(-1) as tokens_obj_flat, masks_obj.view(
1088
+ -1
1089
+ ) as masks_obj_flat:
1090
+ attention_mask_full_unpadded = attention_mask_full_padded.flatten(
1091
+ 0, 1
1092
+ )
1093
+ attention_mask_full_unpadded = _unpad_tensors(
1094
+ attention_mask_full_unpadded.bool(),
1095
+ attention_mask_full_padded.flatten(0, 1),
1096
+ as_nested=False,
1097
+ )
1098
+ masks_obj_flat.all_attention_mask = attention_mask_full_unpadded
1099
+ else:
1100
+ attention_mask_full_unpadded = _unpad_tensors(
1101
+ attention_mask_full_padded.bool(),
1102
+ attention_mask_full_padded,
1103
+ as_nested=False,
1104
+ )
1105
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
1106
+ masks_obj.all_assistant_mask = None
1107
+ masks_obj.padded = MetaData(self.pad_output)
1108
+ out.set(self.masks_key, masks_obj)
1109
+
1110
+ if self.return_log_probs:
1111
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
1112
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
1113
+ # Unfortunately we only have the log-probs for the new tokens, not for the prompt - to get those, we'd need to run the forward pass again
1114
+ if self.pad_output:
1115
+ log_probs_obj_flat.prompt = None
1116
+ log_probs_obj_flat.response = log_probs
1117
+ log_probs_obj_flat.full = None
1118
+ else:
1119
+ log_probs_unpadded = _unpad_tensors(
1120
+ log_probs, attention_mask_response_padded, as_nested=False
1121
+ )
1122
+ log_probs_obj_flat.prompt = None
1123
+ log_probs_obj_flat.response = log_probs_unpadded
1124
+ log_probs_obj_flat.full = None
1125
+ log_probs_obj.padded = MetaData(self.pad_output)
1126
+ out.set(self.log_probs_key, log_probs_obj)
1127
+
1128
+ # Add logits to output if we're in a get_dist call
1129
+ if self._in_get_dist_call:
1130
+ if self.pad_output:
1131
+ out.set("logits", logits)
1132
+ else:
1133
+ logits_full_unpadded = _unpad_tensors(
1134
+ logits, attention_mask_full_padded, as_nested=False
1135
+ )
1136
+ out.set("logits", logits_full_unpadded)
1137
+
1138
+ return out
1139
+
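For orientation, a hedged sketch of the underlying Hugging Face generation call that the method above wraps: left-padded batch in, dict-style output with full sequences out, response obtained by slicing off the prompt length. Not part of the package; the model name and the output_logits flag (recent transformers releases) are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # assumed model
tok.pad_token = tok.eos_token                # gpt2 has no pad token by default
tok.padding_side = "left"                    # left padding keeps the prompt-length split valid
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tok(["Hello there", "Hi"], return_tensors="pt", padding=True)
gen = model.generate(
    **enc,
    max_new_tokens=5,
    return_dict_in_generate=True,
    output_logits=True,  # requires a recent transformers release
)
# gen.sequences includes the prompt; the response starts after the prompt length
response_ids = gen.sequences[:, enc["input_ids"].shape[-1]:]
print(tok.batch_decode(response_ids, skip_special_tokens=True))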
1140
+ def _cat_tensors(
1141
+ self,
1142
+ tokens: torch.Tensor | list[torch.Tensor],
1143
+ response_tokens: torch.Tensor | list[torch.Tensor],
1144
+ cast: torch.dtype | None = None,
1145
+ ):
1146
+ """Concatenate tokens and response tokens."""
1147
+ if isinstance(tokens, list) or isinstance(response_tokens, list):
1148
+ return [
1149
+ self._cat_tensors(t, t_, cast=cast)
1150
+ for t, t_ in _zip_strict(tokens, response_tokens)
1151
+ ]
1152
+ else:
1153
+ result = torch.cat([tokens, response_tokens], dim=-1)
1154
+ if cast is not None:
1155
+ result = result.to(cast)
1156
+ return result
1157
+
1158
+ def _logprobs_from_history_tokens(
1159
+ self, response_tokens, cfg, out, logits_only=False
1160
+ ):
1161
+ """Compute log-probs from history tokens."""
1162
+ pad_val = self.tokenizer.pad_token_id
1163
+
1164
+ if cfg is not None:
1165
+ kwargs = copy(self.generate_kwargs)
1166
+ kwargs["generation_config"] = cfg
1167
+ else:
1168
+ kwargs = self.generate_kwargs
1169
+
1170
+ # non-packed forward pass
1171
+ if self.pad_model_input:
1172
+ # unfortunately HF wants us to use padded tensors
1173
+ tokens_full_padded = response_tokens.get(
1174
+ "input_ids",
1175
+ as_padded_tensor=True,
1176
+ padding_side="left",
1177
+ padding_value=pad_val,
1178
+ )
1179
+ if not isinstance(tokens_full_padded, torch.Tensor):
1180
+ raise ValueError(
1181
+ f"Expected Tensor for tokens_full_padded, got {type(tokens_full_padded)}"
1182
+ )
1183
+ attention_mask_full_padded = response_tokens.get(
1184
+ "attention_mask",
1185
+ as_padded_tensor=True,
1186
+ padding_side="left",
1187
+ padding_value=0,
1188
+ )
1189
+ if not isinstance(attention_mask_full_padded, torch.Tensor):
1190
+ raise ValueError(
1191
+ f"Expected Tensor for attention_mask_full_padded, got {type(attention_mask_full_padded)}"
1192
+ )
1193
+
1194
+ (
1195
+ log_probs_full_padded,
1196
+ logits_full_padded,
1197
+ ) = self._model_forward_with_padded_sequences(
1198
+ tokens_full_padded,
1199
+ attention_mask_full_padded,
1200
+ pad_val=pad_val,
1201
+ logits_only=logits_only,
1202
+ **kwargs,
1203
+ )
1204
+ else:
1205
+ # packed forward pass over nested (unpadded) sequences
1206
+ tokens_full_unpadded = response_tokens.get(
1207
+ "input_ids",
1208
+ as_nested_tensor=True,
1209
+ layout=torch.jagged,
1210
+ )
1211
+ attention_mask_full_unpadded = response_tokens.get(
1212
+ "attention_mask",
1213
+ as_nested_tensor=True,
1214
+ layout=torch.jagged,
1215
+ )
1216
+ (
1217
+ log_probs_full_unpadded,
1218
+ logits_full_unpadded,
1219
+ ) = self._model_forward_with_packed_sequences(
1220
+ # TODO: no padding if we don't need to
1221
+ tokens_full_unpadded,
1222
+ attention_mask_full_unpadded,
1223
+ pad=False,
1224
+ logits_only=logits_only,
1225
+ **kwargs,
1226
+ )
1227
+ tokens_full_padded = pad_sequence(
1228
+ tokens_full_unpadded.unbind(0),
1229
+ batch_first=True,
1230
+ padding_value=pad_val,
1231
+ padding_side="left",
1232
+ )
1233
+ attention_mask_full_padded = pad_sequence(
1234
+ attention_mask_full_unpadded.unbind(0),
1235
+ batch_first=True,
1236
+ padding_value=0,
1237
+ padding_side="left",
1238
+ )
1239
+ if log_probs_full_unpadded is not None:
1240
+ log_probs_full_padded = pad_sequence(
1241
+ log_probs_full_unpadded.unbind(0),
1242
+ batch_first=True,
1243
+ padding_value=0.0,
1244
+ padding_side="left",
1245
+ )
1246
+ else:
1247
+ log_probs_full_padded = None
1248
+ logits_full_padded = pad_sequence(
1249
+ logits_full_unpadded.unbind(0),
1250
+ batch_first=True,
1251
+ padding_value=0.0,
1252
+ padding_side="left",
1253
+ )
1254
+ # Build output TensorClass objects
1255
+ text_obj = Text._from_tensordict(
1256
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1257
+ )
1258
+ text_obj.prompt = None
1259
+ text_obj.response = None
1260
+ text_obj.full = None
1261
+ out.set(self.text_key, text_obj)
1262
+
1263
+ all_assistant_mask_padded = response_tokens.get(
1264
+ "assistant_masks",
1265
+ as_padded_tensor=True,
1266
+ padding_side="left",
1267
+ padding_value=0,
1268
+ )
1269
+ if all_assistant_mask_padded is not None:
1270
+ all_assistant_mask_padded = all_assistant_mask_padded.bool()
1271
+ masks_obj = Masks._from_tensordict(
1272
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1273
+ )
1274
+ if self.pad_output:
1275
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
1276
+ if all_assistant_mask_padded is not None:
1277
+ masks_obj.all_assistant_mask = all_assistant_mask_padded
1278
+ else:
1279
+ masks_obj.all_attention_mask = _unpad_tensors(
1280
+ attention_mask_full_padded.bool(),
1281
+ attention_mask_full_padded,
1282
+ as_nested=False,
1283
+ )
1284
+ if all_assistant_mask_padded is not None:
1285
+ masks_obj.all_assistant_mask = _unpad_tensors(
1286
+ all_assistant_mask_padded,
1287
+ attention_mask_full_padded,
1288
+ as_nested=False,
1289
+ )
1290
+ masks_obj.padded = MetaData(self.pad_output)
1291
+ out.set(self.masks_key, masks_obj)
1292
+
1293
+ tokens_obj = Tokens._from_tensordict(
1294
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1295
+ )
1296
+ if self.pad_output:
1297
+ tokens_obj.full = tokens_full_padded
1298
+ else:
1299
+ input_ids_full_unpadded = _unpad_tensors(
1300
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
1301
+ )
1302
+ tokens_obj.full = input_ids_full_unpadded
1303
+ tokens_obj.response = None
1304
+ tokens_obj.padded = MetaData(self.pad_output)
1305
+ out.set(self.tokens_key, tokens_obj)
1306
+
1307
+ if not logits_only:
1308
+ log_probs_obj = LogProbs._from_tensordict(
1309
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1310
+ )
1311
+ if self.pad_output:
1312
+ log_probs_obj.full = log_probs_full_padded
1313
+ else:
1314
+ log_probs_full_unpadded = _unpad_tensors(
1315
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1316
+ )
1317
+ log_probs_obj.full = log_probs_full_unpadded
1318
+ log_probs_obj.response = None
1319
+ log_probs_obj.padded = MetaData(self.pad_output)
1320
+ out.set(self.log_probs_key, log_probs_obj)
1321
+
1322
+ # Add logits to output if we're in a get_dist call
1323
+ if self._in_get_dist_call:
1324
+ if self.pad_output:
1325
+ out.set("logits", logits_full_padded)
1326
+ else:
1327
+ logits_full_unpadded = _unpad_tensors(
1328
+ logits_full_padded, attention_mask_full_padded, as_nested=False
1329
+ )
1330
+ out.set("logits", logits_full_unpadded)
1331
+
1332
+ return out
1333
+
1334
+ def _from_transformers_generate_text(self, td, cfg, out) -> TensorDictBase:
1335
+ """Generate text from text input."""
1336
+ # Validate input
1337
+ if self.input_key not in td:
1338
+ raise ValueError(
1339
+ f"Expected '{self.input_key}' key for text input mode, "
1340
+ f"but found keys: {list(td.keys())}"
1341
+ )
1342
+
1343
+ text = td.get(self.input_key)
1344
+ if text is None:
1345
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
1346
+ if isinstance(text, NonTensorStack):
1347
+ text = text.tolist()
1348
+ if not isinstance(text, list):
1349
+ raise ValueError(f"Expected list of text for text input, got {type(text)}")
1350
+ return self._generate_from_text(text, cfg, out)
1351
+
1352
+ def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
1353
+ """Compute log-probs from text input."""
1354
+ # Validate input
1355
+ if self.input_key not in td:
1356
+ raise ValueError(
1357
+ f"Expected '{self.input_key}' key for text input mode, "
1358
+ f"but found keys: {list(td.keys())}"
1359
+ )
1360
+
1361
+ text = td.get(self.input_key)
1362
+ if isinstance(text, NonTensorStack):
1363
+ text = text.tolist()
1364
+ if text is None:
1365
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
1366
+ if not isinstance(text, list):
1367
+ raise ValueError(f"Expected list of text for text input, got {type(text)}")
1368
+ # Tokenize the text
1369
+ if self.tokenizer is None:
1370
+ raise ValueError(
1371
+ "Tokenizer is required for log-probs computation with text input"
1372
+ )
1373
+
1374
+ # Convert text to list format
1375
+ if isinstance(text, str):
1376
+ text = [text]
1377
+ elif not isinstance(text, list):
1378
+ text = text.tolist()
1379
+
1380
+ # Tokenize the text
1381
+ tokenizer_kwargs = dict(self.tokenizer_kwargs)
1382
+ with torch.device(
1383
+ self._device
1384
+ ) if self._device is not None else contextlib.nullcontext():
1385
+ tokens_in = self.tokenizer(text, **tokenizer_kwargs)
1386
+
1387
+ if cfg is not None:
1388
+ kwargs = copy(self.generate_kwargs)
1389
+ kwargs["generation_config"] = cfg
1390
+ else:
1391
+ kwargs = self.generate_kwargs
1392
+
1393
+ # Map tokens_in to a TensorDict to make padding easier when we need it
1394
+ tokens_in = (
1395
+ TensorDict(batch_size=len(tokens_in["input_ids"]))
1396
+ .to_lazystack(0)
1397
+ .update(dict(tokens_in))
1398
+ )
1399
+ pad_val = self.padding_value
1400
+
1401
+ if self.pad_model_input:
1402
+ tokens_full_padded = tokens_in.get(
1403
+ "input_ids",
1404
+ as_padded_tensor=True,
1405
+ padding_side="left",
1406
+ padding_value=pad_val,
1407
+ )
1408
+ attention_mask_full_padded = tokens_in.get(
1409
+ "attention_mask",
1410
+ as_padded_tensor=True,
1411
+ padding_side="left",
1412
+ padding_value=0,
1413
+ )
1414
+
1415
+ (
1416
+ log_probs_full_padded,
1417
+ logits_full_padded,
1418
+ ) = self._model_forward_with_padded_sequences(
1419
+ tokens_full_padded,
1420
+ attention_mask_full_padded,
1421
+ pad_val=pad_val,
1422
+ logits_only=logits_only,
1423
+ **kwargs,
1424
+ )
1425
+ else:
1426
+ # packed forward pass
1427
+ tokens_full_unpadded = tokens_in.get(
1428
+ "input_ids",
1429
+ as_nested_tensor=True,
1430
+ layout=torch.jagged,
1431
+ )
1432
+ attention_mask_full_unpadded = tokens_in.get(
1433
+ "attention_mask",
1434
+ as_nested_tensor=True,
1435
+ layout=torch.jagged,
1436
+ )
1437
+ (
1438
+ log_probs_full_unpadded,
1439
+ logits_full_unpadded,
1440
+ ) = self._model_forward_with_packed_sequences(
1441
+ tokens_full_unpadded, attention_mask_full_unpadded, pad=False, **kwargs
1442
+ )
1443
+ tokens_full_padded = pad_sequence(
1444
+ tokens_full_unpadded.unbind(0),
1445
+ batch_first=True,
1446
+ padding_value=pad_val,
1447
+ padding_side="left",
1448
+ )
1449
+ attention_mask_full_padded = pad_sequence(
1450
+ attention_mask_full_unpadded.unbind(0),
1451
+ batch_first=True,
1452
+ padding_value=0,
1453
+ padding_side="left",
1454
+ )
1455
+ log_probs_full_padded = pad_sequence(
1456
+ log_probs_full_unpadded.unbind(0),
1457
+ batch_first=True,
1458
+ padding_value=0.0,
1459
+ padding_side="left",
1460
+ )
1461
+ logits_full_padded = pad_sequence(
1462
+ logits_full_unpadded.unbind(0),
1463
+ batch_first=True,
1464
+ padding_value=0.0,
1465
+ padding_side="left",
1466
+ )
1467
+
1468
+ # Build output TensorClass objects
1469
+ text_obj = Text._from_tensordict(
1470
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1471
+ )
1472
+ text_obj.prompt = None
1473
+ text_obj.response = None
1474
+ text_obj.full = text
1475
+ out.set(self.text_key, text_obj)
1476
+
1477
+ tokens_obj = Tokens._from_tensordict(
1478
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1479
+ )
1480
+ if self.pad_output:
1481
+ tokens_obj.full = tokens_full_padded
1482
+ else:
1483
+ input_ids_full_unpadded = _unpad_tensors(
1484
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
1485
+ )
1486
+ tokens_obj.full = input_ids_full_unpadded
1487
+ tokens_obj.response = None
1488
+ tokens_obj.padded = MetaData(self.pad_output)
1489
+ out.set(self.tokens_key, tokens_obj)
1490
+
1491
+ masks_obj = Masks._from_tensordict(
1492
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1493
+ )
1494
+ if self.pad_output:
1495
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
1496
+ masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
1497
+ else:
1498
+ attention_mask_full_unpadded = _unpad_tensors(
1499
+ attention_mask_full_padded.bool(),
1500
+ attention_mask_full_padded,
1501
+ as_nested=False,
1502
+ )
1503
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
1504
+ masks_obj.all_assistant_mask = td.get(
1505
+ ("masks", "all_assistant_mask"), as_list=True
1506
+ )
1507
+ masks_obj.padded = MetaData(self.pad_output)
1508
+ out.set(self.masks_key, masks_obj)
1509
+
1510
+ if not logits_only:
1511
+ log_probs_obj = LogProbs._from_tensordict(
1512
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1513
+ )
1514
+ if self.pad_output:
1515
+ log_probs_obj.full = log_probs_full_padded
1516
+ else:
1517
+ log_probs_full_unpadded = _unpad_tensors(
1518
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1519
+ )
1520
+ log_probs_obj.full = log_probs_full_unpadded
1521
+ log_probs_obj.response = None
1522
+ log_probs_obj.padded = MetaData(self.pad_output)
1523
+ out.set(self.log_probs_key, log_probs_obj)
1524
+
1525
+ # Add logits to output if we're in a get_dist call
1526
+ if self._in_get_dist_call:
1527
+ if self.pad_output:
1528
+ out.set("logits", logits_full_padded)
1529
+ else:
1530
+ logits_full_unpadded = _unpad_tensors(
1531
+ logits_full_padded, attention_mask_full_padded, as_nested=False
1532
+ )
1533
+ out.set("logits", logits_full_unpadded)
1534
+
1535
+ return out
1536
+
1537
+ def _from_transformers_generate_tokens(
1538
+ self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
1539
+ ) -> TensorDictBase:
1540
+ """Generate text from tokens input."""
1541
+ # Validate input
1542
+ if self.input_key not in td:
1543
+ raise ValueError(
1544
+ f"Expected '{self.input_key}' key for tokens input mode, "
1545
+ f"but found keys: {list(td.keys())}"
1546
+ )
1547
+
1548
+ pad_val = self.tokenizer.pad_token_id
1549
+
1550
+ input_ids_prompt_padded = td.get(
1551
+ self.input_key,
1552
+ as_padded_tensor=True,
1553
+ padding_side="left",
1554
+ padding_value=pad_val,
1555
+ )
1556
+ attention_mask_prompt_padded = td.get(
1557
+ ("masks", "all_attention_mask"),
1558
+ as_padded_tensor=True,
1559
+ padding_side="left",
1560
+ padding_value=False,
1561
+ )
1562
+ if attention_mask_prompt_padded is None:
1563
+ attention_mask_prompt_padded = td.get(
1564
+ self.attention_mask_key,
1565
+ as_padded_tensor=True,
1566
+ padding_side="left",
1567
+ padding_value=False,
1568
+ )
1569
+ if attention_mask_prompt_padded is None:
1570
+ attention_mask_prompt_padded = input_ids_prompt_padded != pad_val
1571
+ return self._generate_from_tokens(
1572
+ input_ids_prompt_padded, attention_mask_prompt_padded, cfg, out
1573
+ )
1574
+
1575
+ def _generate_from_tokens(
1576
+ self,
1577
+ tokens_prompt_padded: torch.Tensor,
1578
+ attention_mask_prompt_padded: torch.Tensor,
1579
+ cfg: dict | None,
1580
+ out: TensorDictBase,
1581
+ ) -> TensorDictBase:
1582
+ if cfg is not None:
1583
+ kwargs = copy(self.generate_kwargs)
1584
+ kwargs["generation_config"] = cfg
1585
+ else:
1586
+ kwargs = self.generate_kwargs
1587
+
1588
+ tokens_out_struct = self.model.generate(
1589
+ input_ids=tokens_prompt_padded,
1590
+ attention_mask=attention_mask_prompt_padded,
1591
+ **kwargs,
1592
+ )
1593
+ tokens_full_padded = tokens_out_struct["sequences"]
1594
+ tokens_response_padded = tokens_full_padded[:, tokens_prompt_padded.shape[-1] :]
1595
+ pad_val = getattr(self.tokenizer, "pad_token_id", None)
1596
+ if pad_val is None:
1597
+ pad_val = self.padding_value
1598
+ attention_mask_response_padded = tokens_response_padded != pad_val
1599
+ attention_mask_full_padded = tokens_full_padded != pad_val
1600
+ tokens_response_unpadded = _unpad_tensors(
1601
+ tokens_response_padded, attention_mask_response_padded, as_nested=False
1602
+ )
1603
+
1604
+ if self.return_log_probs:
1605
+ # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
1606
+ logits_response_padded = tokens_out_struct["logits"]
1607
+ logits_response_padded = torch.stack(list(logits_response_padded), 1)
1608
+ (
1609
+ log_probs_response_padded,
1610
+ logits_response_padded,
1611
+ ) = self._log_probs_generate(
1612
+ tokens_response_padded,
1613
+ logits_response_padded,
1614
+ pad_val=pad_val,
1615
+ pad=False,
1616
+ )
1617
+
1618
+ response_text = self.tokenizer.batch_decode(
1619
+ tokens_response_unpadded, skip_special_tokens=False
1620
+ )
1621
+
1622
+ # Build output TensorClass objects
1623
+ text_obj = Text._from_tensordict(out.empty())
1624
+ text_obj.prompt = None # We don't have text in tokens mode
1625
+ with text_obj.view(-1) as text_obj_flat:
1626
+ text_obj_flat.response = response_text
1627
+ text_obj.full = None # we don't have text in tokens mode so no all_text either
1628
+ out.set(self.text_key, text_obj)
1629
+
1630
+ tokens_obj = Tokens._from_tensordict(out.empty())
1631
+ if not self.pad_output:
1632
+ input_ids_prompt_unpadded = _unpad_tensors(
1633
+ tokens_prompt_padded,
1634
+ attention_mask_prompt_padded,
1635
+ as_nested=False,
1636
+ )
1637
+ if self.num_samples is not None:
1638
+ # replicate tokens
1639
+ for i in range(self.num_samples):
1640
+ tokens_obj[:, i].prompt = (
1641
+ input_ids_prompt_unpadded
1642
+ if not self.pad_output
1643
+ else tokens_prompt_padded
1644
+ )
1645
+ else:
1646
+ tokens_obj.prompt = (
1647
+ input_ids_prompt_unpadded
1648
+ if not self.pad_output
1649
+ else tokens_prompt_padded
1650
+ )
1651
+ with tokens_obj.view(-1) as tokens_obj_flat:
1652
+ if self.pad_output:
1653
+ tokens_obj_flat.response = tokens_response_padded
1654
+ tokens_obj_flat.full = tokens_full_padded
1655
+ else:
1656
+ tokens_obj_flat.response = tokens_response_unpadded
1657
+ tokens_full_unpadded = _unpad_tensors(
1658
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
1659
+ )
1660
+ tokens_obj_flat.full = tokens_full_unpadded
1661
+ tokens_obj.padded = MetaData(self.pad_output)
1662
+ out.set(self.tokens_key, tokens_obj)
1663
+
1664
+ masks_obj = Masks._from_tensordict(out.empty())
1665
+ if out.ndim == 2:
1666
+ attention_mask_full_padded = attention_mask_full_padded.unflatten(
1667
+ 0, (-1, self.num_samples)
1668
+ )
1669
+ if self.pad_output:
1670
+ # Get "real" attention masks
1671
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
1672
+ else:
1673
+ # Get "real" attention masks
1674
+ # We can use select to avoid batch-size problems
1675
+ _td = torch.ones_like(
1676
+ out.select(("tokens", "full"))
1677
+ .copy()
1678
+ .rename_key_(("tokens", "full"), "all_attention_mask")
1679
+ ).bool()
1680
+ del _td["tokens"]
1681
+ masks_obj.update(_td)
1682
+ masks_obj.all_assistant_mask = None
1683
+ masks_obj.padded = MetaData(self.pad_output)
1684
+ out.set(self.masks_key, masks_obj)
1685
+
1686
+ if self.return_log_probs:
1687
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
1688
+ if self.num_samples is None:
1689
+ if self.pad_output:
1690
+ log_probs_obj.response = log_probs_response_padded
1691
+ else:
1692
+ log_probs_response_unpadded = _unpad_tensors(
1693
+ log_probs_response_padded,
1694
+ attention_mask_response_padded,
1695
+ as_nested=False,
1696
+ )
1697
+ log_probs_obj.response = log_probs_response_unpadded
1698
+ else:
1699
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
1700
+ if self.pad_output:
1701
+ log_probs_obj_flat.response = log_probs_response_padded
1702
+ else:
1703
+ log_probs_response_unpadded = _unpad_tensors(
1704
+ log_probs_response_padded,
1705
+ attention_mask_response_padded,
1706
+ as_nested=False,
1707
+ )
1708
+ log_probs_obj_flat.response = log_probs_response_unpadded
1709
+ log_probs_obj.padded = MetaData(self.pad_output)
1710
+ out.set(self.log_probs_key, log_probs_obj)
1711
+
1712
+ return out
1713
+
1714
+ def _from_transformers_logprobs_tokens(
1715
+ self,
1716
+ td: TensorDictBase,
1717
+ cfg: dict | None,
1718
+ out: TensorDictBase,
1719
+ logits_only=False,
1720
+ ) -> TensorDictBase:
1721
+ """Compute log-probs from tokens input."""
1722
+ # Validate input
1723
+ if self.input_key not in td:
1724
+ raise ValueError(
1725
+ f"Expected '{self.input_key}' key for tokens input mode, "
1726
+ f"but found keys: {list(td.keys(isinstance(self.input_key, tuple)))}"
1727
+ )
1728
+
1729
+ pad_val = self.tokenizer.pad_token_id
1730
+
1731
+ if cfg is not None:
1732
+ kwargs = copy(self.generate_kwargs)
1733
+ kwargs["generation_config"] = cfg
1734
+ else:
1735
+ kwargs = self.generate_kwargs
1736
+
1737
+ if self.pad_model_input:
1738
+ tokens_full_padded = td.get(
1739
+ self.input_key,
1740
+ as_padded_tensor=True,
1741
+ padding_side="left",
1742
+ padding_value=pad_val,
1743
+ )
1744
+ # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask
1745
+ attention_mask_full_padded = td.get(
1746
+ ("masks", "all_attention_mask"),
1747
+ as_padded_tensor=True,
1748
+ padding_side="left",
1749
+ padding_value=False,
1750
+ )
1751
+ if attention_mask_full_padded is None:
1752
+ attention_mask_full_padded = td.get(
1753
+ self.attention_mask_key,
1754
+ as_padded_tensor=True,
1755
+ padding_side="left",
1756
+ padding_value=False,
1757
+ )
1758
+ if attention_mask_full_padded is None:
1759
+ attention_mask_full_padded = tokens_full_padded != pad_val
1760
+
1761
+ (
1762
+ log_probs_full_padded,
1763
+ logits_full_padded,
1764
+ ) = self._model_forward_with_padded_sequences(
1765
+ tokens_full_padded,
1766
+ attention_mask_full_padded,
1767
+ pad_val=pad_val,
1768
+ logits_only=logits_only,
1769
+ **kwargs,
1770
+ )
1771
+ else:
1772
+ # packed forward pass
1773
+ # retrieve tokens as nested (unpadded) tensors
1774
+ tokens_full_unpadded = td.get(
1775
+ self.input_key,
1776
+ as_nested_tensor=True,
1777
+ layout=torch.jagged,
1778
+ )
1779
+ if tokens_full_unpadded is None:
1780
+ raise ValueError(
1781
+ f"Expected '{self.input_key}' key for tokens input mode, but found keys: {list(td.keys())}"
1782
+ )
1783
+ # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask
1784
+ attention_mask_full_unpadded = td.get(
1785
+ ("masks", "all_attention_mask"),
1786
+ as_nested_tensor=True,
1787
+ layout=torch.jagged,
1788
+ )
1789
+ if attention_mask_full_unpadded is None:
1790
+ attention_mask_full_unpadded = td.get(
1791
+ self.attention_mask_key,
1792
+ as_nested_tensor=True,
1793
+ layout=torch.jagged,
1794
+ )
1795
+ if attention_mask_full_unpadded is None:
1796
+ # does this even work?
1797
+ attention_mask_full_unpadded = tokens_full_unpadded != pad_val
1798
+
1799
+ (
1800
+ log_probs_full_unpadded,
1801
+ logits_full_unpadded,
1802
+ ) = self._model_forward_with_packed_sequences(
1803
+ # TODO: no padding if we don't need to
1804
+ tokens_full_unpadded,
1805
+ attention_mask_full_unpadded,
1806
+ pad=False,
1807
+ logits_only=logits_only,
1808
+ **kwargs,
1809
+ )
1810
+ tokens_full_padded = pad_sequence(
1811
+ tokens_full_unpadded.unbind(0),
1812
+ batch_first=True,
1813
+ padding_value=pad_val,
1814
+ padding_side="left",
1815
+ )
1816
+ attention_mask_full_padded = pad_sequence(
1817
+ attention_mask_full_unpadded.unbind(0),
1818
+ batch_first=True,
1819
+ padding_value=0,
1820
+ padding_side="left",
1821
+ )
1822
+ if log_probs_full_unpadded is not None:
1823
+ log_probs_full_padded = pad_sequence(
1824
+ log_probs_full_unpadded.unbind(0),
1825
+ batch_first=True,
1826
+ padding_value=0.0,
1827
+ padding_side="left",
1828
+ )
1829
+ else:
1830
+ log_probs_full_padded = None
1831
+ logits_full_padded = pad_sequence(
1832
+ logits_full_unpadded.unbind(0),
1833
+ batch_first=True,
1834
+ padding_value=0.0,
1835
+ padding_side="left",
1836
+ )
1837
+
1838
+ # Build output TensorClass objects
1839
+ text_obj = Text._from_tensordict(
1840
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1841
+ )
1842
+ text_obj.prompt = None
1843
+ text_obj.response = None
1844
+ text_obj.full = None
1845
+ out.set(self.text_key, text_obj)
1846
+
1847
+ tokens_obj = Tokens._from_tensordict(
1848
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1849
+ )
1850
+ if not self.pad_output:
1851
+ input_ids_full_unpadded = _unpad_tensors(
1852
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
1853
+ )
1854
+ tokens_obj.full = input_ids_full_unpadded
1855
+ else:
1856
+ tokens_obj.full = tokens_full_padded
1857
+ tokens_obj.response = None
1858
+ tokens_obj.padded = MetaData(self.pad_output)
1859
+ out.set(self.tokens_key, tokens_obj)
1860
+
1861
+ masks_obj = Masks._from_tensordict(
1862
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1863
+ )
1864
+ if self.pad_output:
1865
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
1866
+ masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
1867
+ else:
1868
+ masks_obj.all_attention_mask = _unpad_tensors(
1869
+ attention_mask_full_padded.bool(),
1870
+ attention_mask_full_padded,
1871
+ as_nested=False,
1872
+ )
1873
+ masks_obj.all_assistant_mask = td.get(
1874
+ ("masks", "all_assistant_mask"), as_list=True
1875
+ )
1876
+
1877
+ masks_obj.padded = MetaData(self.pad_output)
1878
+ out.set(self.masks_key, masks_obj)
1879
+
1880
+ if not logits_only:
1881
+ log_probs_obj = LogProbs._from_tensordict(
1882
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1883
+ )
1884
+ if self.pad_output:
1885
+ log_probs_obj.full = log_probs_full_padded
1886
+ else:
1887
+ log_probs_full_unpadded = _unpad_tensors(
1888
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1889
+ )
1890
+ log_probs_obj.full = log_probs_full_unpadded
1891
+ log_probs_obj.response = None
1892
+ log_probs_obj.padded = MetaData(self.pad_output)
1893
+ out.set(self.log_probs_key, log_probs_obj)
1894
+
1895
+ # Add logits to output if we're in a get_dist call
1896
+ if self._in_get_dist_call:
1897
+ if self.pad_output:
1898
+ out.set("logits", logits_full_padded)
1899
+ else:
1900
+ logits_full_unpadded = _unpad_tensors(
1901
+ logits_full_padded, attention_mask_full_padded, as_nested=False
1902
+ )
1903
+ out.set("logits", logits_full_unpadded)
1904
+ return out
1905
+
1906
+ @classmethod
1907
+ def _log_probs_generate(cls, tokens, logits, pad_val=-100, pad: bool = True):
1908
+ if pad:
1909
+ tokens = pad_sequence(
1910
+ tokens,
1911
+ padding_value=pad_val,
1912
+ batch_first=True,
1913
+ padding_side="left",
1914
+ )
1915
+ logits = pad_sequence(
1916
+ logits,
1917
+ padding_value=0.0,
1918
+ batch_first=True,
1919
+ padding_side="left",
1920
+ )
1921
+
1922
+ # logits = logits.log_softmax(dim=-1)
1923
+ # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
1924
+ td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
1925
+ with td.flatten() as tdflat:
1926
+ tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
1927
+ tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
1928
+ )
1929
+ td["log_probs"][:, 0] = 0
1930
+ log_probs = td["log_probs"]
1931
+ return log_probs, logits
1932
+
1933
+ def _compute_log_probs_from_model_output(
1934
+ self, model_output, input_ids, attention_mask, pad_val, logits_only=False
1935
+ ):
1936
+ """Compute log-probs from model output without modifying original tensors.
1937
+
1938
+ Args:
1939
+ model_output: Output from the model containing logits
1940
+ input_ids: Original input token ids
1941
+ attention_mask: Original attention mask
1942
+ pad_val: Padding token value to ignore in loss computation
1943
+ logits_only: Whether to return only the logits.
1944
+
1945
+ Returns:
1946
+ tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
1947
+ and shifted_logits are the logits shifted to align with tokens
1948
+ """
1949
+ logits = model_output["logits"]
1950
+
1951
+ # Create shifted versions for log-prob computation without modifying originals
1952
+ shifted_logits = logits[:, :-1, :]
1953
+ # shifted_logits = shifted_logits - shifted_logits.logsumexp(dim=-1, keepdim=True)
1954
+ shifted_logits = torch.cat(
1955
+ [torch.zeros_like(shifted_logits[:, :1]), shifted_logits], 1
1956
+ )
1957
+
1958
+ shifted_input_ids = input_ids[:, 1:]
1959
+ shifted_input_ids = torch.cat(
1960
+ [torch.zeros_like(shifted_input_ids[:, :1]), shifted_input_ids], 1
1961
+ )
1962
+
1963
+ # Check that the shape is correct
1964
+ if shifted_logits.shape[-2] != shifted_input_ids.shape[-1]:
1965
+ raise ValueError(
1966
+ f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
1967
+ )
1968
+ if logits_only:
1969
+ return None, shifted_logits
1970
+
1971
+ # Compute log-probs
1972
+ td = TensorDict(
1973
+ logits=shifted_logits, tokens=shifted_input_ids
1974
+ ).auto_batch_size_()
1975
+ with td.flatten() as tdflat:
1976
+ tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
1977
+ tdflat["logits"],
1978
+ tdflat["tokens"],
1979
+ reduction="none",
1980
+ ignore_index=pad_val,
1981
+ )
1982
+ # For consistency with vllm, we set the log-probs of the first token to 0
1983
+ # However, with left padding the first real token may not sit at index 0 - we want the
+ # first position where the attention mask becomes True (the leftmost real token)
1985
+ attention_mask = attention_mask.bool()
1986
+ attention_mask_first_left = ~attention_mask[:, :-1] & attention_mask[:, 1:]
1987
+ attention_mask_first_left = torch.cat(
1988
+ [
1989
+ torch.zeros_like(attention_mask_first_left[..., :1]),
1990
+ attention_mask_first_left,
1991
+ ],
1992
+ -1,
1993
+ )
1994
+ attention_mask_first_left[~(attention_mask_first_left.any(-1)), 0] = True
1995
+ assert attention_mask_first_left.any(-1).all()
1996
+ attention_mask_first_left = attention_mask_first_left | ~attention_mask
1997
+ td["log_probs"][attention_mask_first_left] = 0
1998
+
1999
+ return td["log_probs"], shifted_logits
2000
+
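A self-contained sketch of the shift-and-cross-entropy trick documented above, using random logits; it only zeroes position 0 and does not reproduce the left-padding handling of the method. Illustrative only, plain PyTorch assumed.

import torch
import torch.nn.functional as F

def token_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, pad_val: int = -100) -> torch.Tensor:
    # logits at position t predict the token at t+1, so shift before gathering
    shifted_logits = logits[:, :-1, :]          # (B, L-1, V)
    targets = input_ids[:, 1:]                  # (B, L-1)
    lp = -F.cross_entropy(
        shifted_logits.flatten(0, 1),           # (B*(L-1), V)
        targets.flatten(),                      # (B*(L-1),)
        reduction="none",
        ignore_index=pad_val,
    ).view_as(targets)
    # the first token has no preceding context, so its log-prob is set to 0
    return torch.cat([torch.zeros_like(lp[:, :1]), lp], dim=1)

B, L, V = 2, 5, 11
lp = token_log_probs(torch.randn(B, L, V), torch.randint(0, V, (B, L)))
assert lp.shape == (B, L) and (lp[:, 0] == 0).all()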
2001
+ def get_dist(
2002
+ self,
2003
+ tensordict: TensorDictBase,
2004
+ tensordict_out: TensorDictBase | None = None,
2005
+ logits_key: NestedKey = "logits",
2006
+ mask_key: NestedKey | None = None,
2007
+ as_padded_tensor: bool | None = None,
2008
+ as_nested_tensor: bool | None = None,
2009
+ padding_value: float | None = None,
2010
+ padding_side: str = "right",
2011
+ layout: torch.layout | None = None,
2012
+ **kwargs,
2013
+ ) -> D.Distribution:
2014
+ """Get distribution from logits/log-probs with optional masking.
2015
+
2016
+ This method enables logits computation for distribution creation.
2017
+ """
2018
+ self._in_get_dist_call = True
2019
+ self.out_keys += ["logits"]
2020
+ try:
2021
+ return super().get_dist(
2022
+ tensordict,
2023
+ tensordict_out,
2024
+ logits_key,
2025
+ mask_key,
2026
+ as_padded_tensor,
2027
+ as_nested_tensor,
2028
+ padding_value,
2029
+ padding_side,
2030
+ layout,
2031
+ **kwargs,
2032
+ )
2033
+ finally:
2034
+ self._in_get_dist_call = False
2035
+ self.out_keys.remove("logits")
2036
+
2037
+ def _get_dist_with_prompt_mask(
2038
+ self,
2039
+ tensordict: TensorDictBase,
2040
+ tokens_key: NestedKey = ("tokens", "prompt"),
2041
+ logits_key: NestedKey = "logits",
2042
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
2043
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
2044
+ **kwargs,
2045
+ ) -> D.Distribution:
2046
+ """Get distribution masked to only include response tokens (exclude prompt).
2047
+
2048
+ This method enables logits computation for distribution creation.
2049
+
2050
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2051
+ """
2052
+ self._in_get_dist_call = True
2053
+ self.out_keys += ["logits"]
2054
+ try:
2055
+ return super()._get_dist_with_prompt_mask(
2056
+ tensordict,
2057
+ tokens_key,
2058
+ logits_key,
2059
+ assistant_mask_key,
2060
+ attention_mask_key,
2061
+ **kwargs,
2062
+ )
2063
+ finally:
2064
+ self._in_get_dist_call = False
2065
+ self.out_keys.remove("logits")
2066
+
2067
+ def _get_dist_with_assistant_mask(
2068
+ self,
2069
+ tensordict: TensorDictBase,
2070
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
2071
+ logits_key: NestedKey = "logits",
2072
+ **kwargs,
2073
+ ) -> D.Distribution:
2074
+ """Get distribution masked to only include assistant tokens.
2075
+
2076
+ This method enables logits computation for distribution creation.
2077
+
2078
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2079
+ """
2080
+ self._in_get_dist_call = True
2081
+ self.out_keys += ["logits"]
2082
+ try:
2083
+ return super()._get_dist_with_assistant_mask(
2084
+ tensordict, assistant_mask_key, logits_key, **kwargs
2085
+ )
2086
+ finally:
2087
+ self._in_get_dist_call = False
2088
+ self.out_keys.remove("logits")
2089
+
2090
+ def _get_dist_with_attention_mask(
2091
+ self,
2092
+ tensordict: TensorDictBase,
2093
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
2094
+ logits_key: NestedKey = "logits",
2095
+ **kwargs,
2096
+ ) -> D.Distribution:
2097
+ """Get distribution masked using attention mask.
2098
+
2099
+ This method enables logits computation for distribution creation.
2100
+
2101
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2102
+ """
2103
+ self._in_get_dist_call = True
2104
+ self.out_keys += ["logits"]
2105
+ try:
2106
+ return super()._get_dist_with_attention_mask(
2107
+ tensordict, attention_mask_key, logits_key, **kwargs
2108
+ )
2109
+ finally:
2110
+ self._in_get_dist_call = False
2111
+ self.out_keys.remove("logits")
2112
+
2113
+ def _get_dist_with_custom_mask(
2114
+ self,
2115
+ tensordict: TensorDictBase,
2116
+ mask: torch.Tensor,
2117
+ logits_key: NestedKey = "logits",
2118
+ **kwargs,
2119
+ ) -> D.Distribution:
2120
+ """Get distribution with custom mask.
2121
+
2122
+ This method enables logits computation for distribution creation.
2123
+ """
2124
+ self._in_get_dist_call = True
2125
+ self.out_keys += ["logits"]
2126
+ try:
2127
+ return super()._get_dist_with_custom_mask(
2128
+ tensordict, mask, logits_key, **kwargs
2129
+ )
2130
+ finally:
2131
+ self._in_get_dist_call = False
2132
+ self.out_keys.remove("logits")
2133
+
2134
+ # Convenience methods for common LLM training scenarios
2135
+ def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
2136
+ """Get distribution suitable for SFT loss (response tokens only).
2137
+
2138
+ This method enables logits computation for distribution creation.
2139
+
2140
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2141
+ """
2142
+ self._in_get_dist_call = True
2143
+ self.out_keys += ["logits"]
2144
+ try:
2145
+ return super()._get_sft_dist(tensordict, **kwargs)
2146
+ finally:
2147
+ self._in_get_dist_call = False
2148
+ self.out_keys.remove("logits")
2149
+
2150
+ def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
2151
+ """Get distribution suitable for RLHF loss (assistant tokens only).
2152
+
2153
+ This method enables logits computation for distribution creation.
2154
+
2155
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2156
+ """
2157
+ self._in_get_dist_call = True
2158
+ self.out_keys += ["logits"]
2159
+ try:
2160
+ return super()._get_rlhf_dist(tensordict, **kwargs)
2161
+ finally:
2162
+ self._in_get_dist_call = False
2163
+ self.out_keys.remove("logits")
2164
+
2165
+ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
2166
+ """Get distribution suitable for generic losses (all tokens).
2167
+
2168
+ This method enables logits computation for distribution creation.
2169
+
2170
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
2171
+ """
2172
+ self._in_get_dist_call = True
2173
+ self.out_keys += ["logits"]
2174
+ try:
2175
+ return super()._get_generic_dist(tensordict, **kwargs)
2176
+ finally:
2177
+ self._in_get_dist_call = False
2178
+ self.out_keys.remove("logits")
2179
+
2180
+ def _pack_sequences(
2181
+ self,
2182
+ input_ids: torch.nested.NestedTensor,
2183
+ attention_mask: torch.nested.NestedTensor,
2184
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
2185
+ """Pack sequences into a single tensor."""
2186
+ packed_input_ids = input_ids.values()
2187
+ lengths = input_ids.lengths()
2188
+ if lengths is None:
2189
+ offsets = input_ids.offsets()
2190
+ lengths = offsets.diff()
2191
+ offsets = offsets[1:]
2192
+ else:
2193
+ offsets = lengths.cumsum(0)
2194
+ # Create block-diagonal attention mask to prevent cross-sequence attention
2195
+ attention_mask = self._create_block_diagonal_attention_mask(lengths)
2196
+ # Create position IDs that restart for each sequence
2197
+ position_ids = self._create_packed_position_ids(
2198
+ lengths, total_length=packed_input_ids.numel()
2199
+ )
2200
+
2201
+ packing_metadata = {
2202
+ "sequence_lengths": lengths,
2203
+ "cumulative_lengths": offsets,
2204
+ "attention_mask": attention_mask,
2205
+ "position_ids": position_ids,
2206
+ }
2207
+
2208
+ return (
2209
+ packed_input_ids.unsqueeze(0),
2210
+ attention_mask.unsqueeze(0),
2211
+ packing_metadata,
2212
+ )
2213
+
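A small sketch of what packing amounts to with jagged nested tensors (assuming a recent PyTorch): the flat values buffer plus per-sequence lengths are exactly what the packing metadata above carries. Not part of the package.

import torch

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
packed = nt.values()             # flat buffer: [5, 6, 7, 8, 9]
lengths = nt.offsets().diff()    # per-sequence lengths: [3, 2]
print(packed.tolist(), lengths.tolist())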
2214
+ def _model_forward_with_padded_sequences(
2215
+ self,
2216
+ tokens_full_padded: torch.Tensor,
2217
+ attention_mask_full_padded: torch.Tensor,
2218
+ *,
2219
+ pad_val: float | int | torch.Tensor | None = None,
2220
+ logits_only: bool = False,
2221
+ **kwargs,
2222
+ ):
2223
+ """Forward pass with padded sequences."""
2224
+ # Error handling for empty sequences
2225
+ if tokens_full_padded.numel() == 0:
2226
+ raise ValueError(
2227
+ "Input contains empty sequences. Packing/padding requires at least one token per sequence."
2228
+ )
2229
+ # Error handling for overlong sequences
2230
+ config = getattr(self.model, "config", None)
2231
+ max_len = getattr(config, "max_position_embeddings", None)
2232
+ if max_len is not None and tokens_full_padded.shape[-1] > max_len:
2233
+ raise ValueError(
2234
+ f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
2235
+ )
2236
+ tokens_out_struct = self.model(
2237
+ tokens_full_padded, attention_mask_full_padded, **kwargs
2238
+ )
2239
+ (
2240
+ log_probs_full_padded,
2241
+ logits_full_padded,
2242
+ ) = self._compute_log_probs_from_model_output(
2243
+ tokens_out_struct,
2244
+ tokens_full_padded,
2245
+ attention_mask_full_padded,
2246
+ pad_val,
2247
+ logits_only=logits_only,
2248
+ )
2249
+ return log_probs_full_padded, logits_full_padded
2250
+
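For reference, the padded forward pass performed by the method above boils down to a single Hugging Face call returning logits of shape (batch, seq_len, vocab_size). A hedged sketch with an assumed model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # assumed model
tok.pad_token = tok.eos_token
tok.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tok(["Hello there", "Hi"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.logits.shape)  # (batch, seq_len, vocab_size)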
2251
+ def _model_forward_with_packed_sequences(
2252
+ self,
2253
+ flat_input_ids: torch.Tensor,
2254
+ block_diag_attention_mask: torch.Tensor,
2255
+ *,
2256
+ pad: bool = True,
2257
+ logits_only: bool = False,
2258
+ **kwargs,
2259
+ ):
2260
+ """Pack sequences into a single tensor and forward them through the model.
2261
+
2262
+ Args:
2263
+ flat_input_ids (NestedTensor): NestedTensor of shape (batch_size, -1)
2264
+ block_diag_attention_mask (NestedTensor): NestedTensor of shape (batch_size, -1)
2265
+
2266
+ pad (bool): Whether to pad the output tensors.
+ logits_only (bool): Whether to return only the logits.
+ kwargs (dict): Additional keyword arguments to pass to the model.
+
+ Returns:
+ tuple: ``(log_probs, logits)`` for the packed sequences; ``log_probs`` is ``None`` when ``logits_only=True``.
+
2271
+ """
2272
+ # Error handling for empty sequences
2273
+ if flat_input_ids.numel() == 0:
2274
+ raise ValueError(
2275
+ "Input contains empty sequences. Packing requires at least one token per sequence."
2276
+ )
2277
+ # Error handling for overlong sequences
2278
+ # Note: Skipping this check for nested tensors due to symbolic representation issues
2279
+ # The model will handle sequence length limits internally
2280
+ max_len = getattr(self.model.config, "max_position_embeddings", None)
2281
+ if max_len is not None and not flat_input_ids.is_nested:
2282
+ # Only check for regular tensors, not nested tensors
2283
+ actual_size = flat_input_ids.shape[-1]
2284
+ if actual_size > max_len:
2285
+ raise ValueError(
2286
+ f"Input sequence length ({actual_size}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
2287
+ )
2288
+ (
2289
+ flat_input_ids,
2290
+ block_diag_attention_mask,
2291
+ packing_metadata,
2292
+ ) = self._pack_sequences(flat_input_ids, block_diag_attention_mask)
2293
+
2294
+ outputs = self.model(
2295
+ input_ids=flat_input_ids,
2296
+ attention_mask=block_diag_attention_mask.unsqueeze(0),
2297
+ position_ids=packing_metadata["position_ids"],
2298
+ use_cache=False, # Disable KV cache for packing
2299
+ **kwargs,
2300
+ )
2301
+ log_probs, logits = self._unpack_outputs(
2302
+ outputs, packing_metadata, flat_input_ids, pad=pad, logits_only=logits_only
2303
+ )
2304
+ return log_probs, logits
2305
+
2306
+ def _unpack_outputs(
2307
+ self,
2308
+ outputs,
2309
+ packing_metadata: dict[str, Any],
2310
+ flat_input_ids: torch.Tensor,
2311
+ pad: bool = True,
2312
+ logits_only: bool = False,
2313
+ ) -> tuple[torch.Tensor | None, torch.Tensor]:
2314
+ """Unpack outputs using nested tensors - zero syncs."""
2315
+ # use cross_entropy to compute log_probs
2316
+ log_probs, logits = self._compute_log_probs_from_model_output(
2317
+ outputs,
2318
+ flat_input_ids,
2319
+ torch.ones_like(flat_input_ids, dtype=torch.bool),
2320
+ -100,
2321
+ logits_only=logits_only,
2322
+ )
2323
+ # check shapes: [1, L] for log_probs, [1, L, vocab_size] for logits
2324
+ sequence_lengths = packing_metadata["sequence_lengths"]
2325
+ if logits_only:
2326
+ log_probs = None
2327
+ else:
2328
+ if log_probs.shape != logits.shape[:2]:
2329
+ raise ValueError(
2330
+ f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
2331
+ )
2332
+ if log_probs.ndim != 2:
2333
+ raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
2334
+ if logits.ndim != 3:
2335
+ raise ValueError(f"Logits shape {logits.shape=} is not 3D")
2336
+ if log_probs.shape[1] != sequence_lengths.sum():
2337
+ raise ValueError(
2338
+ f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
2339
+ )
2340
+
2341
+ log_probs = log_probs.squeeze(0)
2342
+ nested_logprobs = torch.nested.nested_tensor_from_jagged(
2343
+ log_probs,
2344
+ lengths=sequence_lengths,
2345
+ )
2346
+
2347
+ logits = logits.squeeze(0)
2348
+ nested_logits = torch.nested.nested_tensor_from_jagged(
2349
+ logits, # Remove batch dim: (total_length, vocab_size)
2350
+ lengths=sequence_lengths,
2351
+ )
2352
+
2353
+ if logits_only:
2354
+ if pad:
2355
+ return None, nested_logits.to_padded_tensor(padding=0.0)
2356
+ return None, nested_logits
2357
+ else:
2358
+ if pad:
2359
+ return nested_logprobs.to_padded_tensor(
2360
+ padding=0.0
2361
+ ), nested_logits.to_padded_tensor(padding=0.0)
2362
+ return nested_logprobs, nested_logits
2363
+
2364
+ def _create_block_diagonal_attention_mask(
2365
+ self, sequence_lengths: torch.Tensor
2366
+ ) -> torch.Tensor:
2367
+ """Efficient creation of a block-diagonal attention mask.
2368
+
2369
+ Zero CUDA syncs; the only host-side integer needed is len(sequence_lengths), so the construction stays compilable.
2370
+
2371
+ Args:
2372
+ sequence_lengths: Tensor of shape (batch_size,) containing the lengths of the sequences
2373
+
2374
+ Returns:
2375
+ attention_mask: Bool tensor of shape (total_length, total_length)
+ where each position can only attend to positions within the same sequence.
2377
+ """
2378
+ seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
2379
+ position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)
2380
+
2381
+ attention_mask = position_to_seq_id.unsqueeze(
2382
+ 1
2383
+ ) == position_to_seq_id.unsqueeze(0)
2384
+ return attention_mask
2385
+
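A quick illustration of the block-diagonal construction for lengths [2, 3] (plain PyTorch, not part of the package): positions from different sequences can never attend to one another.

import torch

lengths = torch.tensor([2, 3])
seq_ids = torch.arange(len(lengths)).repeat_interleave(lengths)  # [0, 0, 1, 1, 1]
mask = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(0)              # (5, 5) bool
assert mask[:2, :2].all() and not mask[:2, 2:].any()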
2386
+ def repeat_interleave_causal(self, sequence_lengths: torch.Tensor) -> torch.Tensor:
2387
+ """Same as _create_block_diagonal_attention_mask, but with causal masking."""
2388
+ total_length = sequence_lengths.sum()
2389
+
2390
+ seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
2391
+ position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)
2392
+
2393
+ positions = torch.arange(int(total_length), device=sequence_lengths.device)
2394
+
2395
+ same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(
2396
+ 0
2397
+ )
2398
+ causal = positions.unsqueeze(0) <= positions.unsqueeze(1)
2399
+
2400
+ attention_mask = same_sequence & causal
2401
+ return attention_mask
2402
+
2403
+ def _create_packed_position_ids(
2404
+ self, sequence_lengths: torch.Tensor, total_length: int | None = None
2405
+ ) -> torch.Tensor:
2406
+ """Create position IDs that restart from 0 for each sequence.
2407
+
2408
+ For sequences of length [3, 2], creates: [0, 1, 2, 0, 1]
2409
+
2410
+ No cuda syncs.
2411
+ """
2412
+ if total_length is None:
2413
+ total_length = int(sequence_lengths.sum().item())
2414
+
2415
+ # Create global position IDs: [0, 1, 2, 3, 4]
2416
+ global_positions = torch.arange(total_length, device=sequence_lengths.device)
2417
+
2418
+ # Create sequence start offsets repeated for each position: [0, 0, 0, 3, 3]
2419
+ offsets = torch.cat(
2420
+ [
2421
+ torch.zeros(1, device=sequence_lengths.device),
2422
+ sequence_lengths.cumsum(0)[:-1],
2423
+ ]
2424
+ )
2425
+ sequence_starts = offsets.repeat_interleave(sequence_lengths)
2426
+
2427
+ # Subtract to get local positions: [0, 1, 2, 0, 1]
2428
+ position_ids = global_positions - sequence_starts
2429
+
2430
+ return position_ids.unsqueeze(0) # (1, total_length)
2431
+
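The per-sequence position restart can be checked in isolation with plain PyTorch (illustrative only, mirroring the [3, 2] example from the docstring above):

import torch

lengths = torch.tensor([3, 2])
global_pos = torch.arange(int(lengths.sum()))                             # [0, 1, 2, 3, 4]
starts = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)[:-1]])
position_ids = global_pos - starts.repeat_interleave(lengths)             # [0, 1, 2, 0, 1]
assert position_ids.tolist() == [0, 1, 2, 0, 1]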
2432
+
2433
+ class RemoteTransformersWrapper:
2434
+ """A remote Ray actor wrapper for TransformersWrapper that provides a simplified interface.
2435
+
2436
+ This class wraps a TransformersWrapper instance as a Ray actor, allowing remote execution
2437
+ while providing a clean interface that doesn't require explicit `remote()` and `get()` calls.
2438
+
2439
+ Args:
2440
+ model (str): The Hugging Face Transformers model to wrap.
2441
+ Must be a string (model name or path) that will be passed to `transformers.AutoModelForCausalLM.from_pretrained`.
2442
+ Transformers models are not serializable, so only model names/paths are supported.
2443
+ max_concurrency (int, optional): Maximum number of concurrent calls to the remote actor. Defaults to 16.
2444
+ validate_model (bool, optional): Whether to validate the model. Defaults to True.
2445
+ num_gpus (int, optional): Number of GPUs to use. Defaults to 0.
2446
+ num_cpus (int, optional): Number of CPUs to use. Defaults to 0.
2447
+ **kwargs: All other arguments are passed directly to TransformersWrapper.
2448
+
2449
+ Example:
2450
+ >>> import ray
2451
+ >>> from torchrl.modules.llm.policies import RemoteTransformersWrapper
2452
+ >>>
2453
+ >>> # Initialize Ray if not already done
2454
+ >>> if not ray.is_initialized():
2455
+ ... ray.init()
2456
+ >>>
2457
+ >>> # Create remote wrapper
2458
+ >>> remote_wrapper = RemoteTransformersWrapper(
2459
+ ... model="gpt2",
2460
+ ... input_mode="history",
2461
+ ... generate=True,
2462
+ ... generate_kwargs={"max_new_tokens": 50}
2463
+ ... )
2464
+ >>>
2465
+ >>> # Use like a regular wrapper (no remote/get calls needed)
2466
+ >>> result = remote_wrapper(tensordict_input)
2467
+ >>> print(result["text"].response)
2468
+ """
2469
+
+    def __init__(
+        self,
+        model,
+        max_concurrency: int = 16,
+        validate_model: bool = True,
+        actor_name: str | None = None,
+        num_gpus: int = 1,
+        num_cpus: int = 1,
+        **kwargs,
+    ):
+        import ray
+
+        # Validate model parameter - only strings are allowed for Transformers
+        if not isinstance(model, str) and validate_model:
+            raise ValueError(
+                "For RemoteTransformersWrapper, the model parameter must be a string "
+                f"(model name or path). Got type: {type(model)}. "
+                "Transformers models are not serializable, so only model names/paths are supported. "
+                "You can bypass this check by setting validate_model=False."
+            )
+
+        if not ray.is_initialized():
+            ray.init()
+
+        if actor_name is not None:
+            # Check if an actor with this name already exists
+            try:
+                existing_actor = ray.get_actor(actor_name)
+                # If we can get the actor, assume it's alive and use it
+                self._remote_wrapper = existing_actor
+                torchrl_logger.info(f"Using existing actor {actor_name}")
+                return
+            except ValueError:
+                # Actor doesn't exist, create a new one
+                torchrl_logger.info(f"Creating new actor {actor_name}")
+
+        # Create the remote actor with the unique name
+        self._remote_wrapper = (
+            ray.remote(TransformersWrapper)
+            .options(
+                max_concurrency=max_concurrency,
+                name=actor_name,
+                num_gpus=num_gpus,
+                num_cpus=num_cpus,
+            )
+            .remote(model, **kwargs)
+        )
+
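As a hedged usage sketch (model name, actor name and generation arguments are placeholders), the actor_name handling above means that two constructions with the same name end up sharing a single underlying Ray actor:

    # The first construction creates a named actor; the second finds it via ray.get_actor
    # and reuses it instead of spawning another copy of the model.
    w1 = RemoteTransformersWrapper(
        model="gpt2",
        actor_name="tf_policy",
        generate=True,
        generate_kwargs={"max_new_tokens": 16},
    )
    w2 = RemoteTransformersWrapper(model="gpt2", actor_name="tf_policy")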
+    def __call__(self, tensordict, **kwargs):
+        """Forward pass that automatically handles remote execution."""
+        import ray
+
+        return ray.get(self._remote_wrapper.forward.remote(tensordict, **kwargs))
+
+    def get_new_version(self, **kwargs):
+        """Get a new version of the wrapper with altered parameters."""
+        import ray
+
+        return ray.get(self._remote_wrapper.get_new_version.remote(**kwargs))
+
+    def get_dist(self, tensordict, **kwargs):
+        """Get distribution from logits/log-probs with optional masking."""
+        import ray
+
+        return ray.get(self._remote_wrapper.get_dist.remote(tensordict, **kwargs))
+
+    def get_dist_with_prompt_mask(self, tensordict, **kwargs):
+        """Get distribution masked to only include response tokens (exclude prompt)."""
+        import ray
+
+        return ray.get(
+            self._remote_wrapper.get_dist_with_prompt_mask.remote(tensordict, **kwargs)
+        )
+
+    def _get_dist_with_assistant_mask(self, tensordict, **kwargs):
+        """Get distribution masked to only include assistant tokens."""
+        import ray
+
+        return ray.get(
+            self._remote_wrapper._get_dist_with_assistant_mask.remote(
+                tensordict, **kwargs
+            )
+        )
+
+    def _get_dist_with_attention_mask(self, tensordict, **kwargs):
+        """Get distribution masked using attention mask."""
+        import ray
+
+        return ray.get(
+            self._remote_wrapper._get_dist_with_attention_mask.remote(
+                tensordict, **kwargs
+            )
+        )
+
+    def _get_dist_with_custom_mask(self, tensordict, **kwargs):
+        """Get distribution with custom mask."""
+        import ray
+
+        return ray.get(
+            self._remote_wrapper._get_dist_with_custom_mask.remote(tensordict, **kwargs)
+        )
+
+    def _get_sft_dist(self, tensordict, **kwargs):
+        """Get distribution suitable for SFT loss (response tokens only)."""
+        import ray
+
+        return ray.get(self._remote_wrapper._get_sft_dist.remote(tensordict, **kwargs))
+
+    def _get_rlhf_dist(self, tensordict, **kwargs):
+        """Get distribution suitable for RLHF loss (assistant tokens only)."""
+        import ray
+
+        return ray.get(self._remote_wrapper._get_rlhf_dist.remote(tensordict, **kwargs))
+
+    def _get_generic_dist(self, tensordict, **kwargs):
+        """Get distribution suitable for generic losses (all tokens)."""
+        import ray
+
+        return ray.get(
+            self._remote_wrapper._get_generic_dist.remote(tensordict, **kwargs)
+        )
+
+    def log_prob(self, data, **kwargs):
+        """Compute log probabilities."""
+        import ray
+
+        return ray.get(self._remote_wrapper.log_prob.remote(data, **kwargs))
+
+    def cleanup_batching(self):
+        """Clean up batching resources."""
+        import ray
+
+        return ray.get(self._remote_wrapper.cleanup_batching.remote())
+
+    def __del__(self):
+        """Cleanup when the wrapper is destroyed."""
+        try:
+            import ray
+
+            if hasattr(self, "_remote_wrapper") and ray.is_initialized():
+                # Clean up batching resources
+                try:
+                    ray.get(self._remote_wrapper.cleanup_batching.remote())
+                except Exception:
+                    pass  # Ignore cleanup errors during destruction
+        except Exception:
+            pass  # Ignore any errors during cleanup
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit with cleanup."""
+        self.cleanup_batching()
+
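For illustration (the model name, generation arguments, and tensordict_input are placeholders), the __enter__/__exit__ pair above lets the wrapper be used as a context manager so that cleanup_batching() runs automatically:

    with RemoteTransformersWrapper(
        model="gpt2",
        generate=True,
        generate_kwargs={"max_new_tokens": 16},
    ) as wrapper:
        result = wrapper(tensordict_input)  # forwarded to the remote actor via ray.get
    # On leaving the block, __exit__ has called cleanup_batching() on the remote actor.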
+    def get_batching_state(self):
+        """Get the current batching state."""
+        import ray
+
+        return ray.get(self._remote_wrapper.get_batching_state.remote())
+
+    @property
+    def generate(self):
+        """Whether text generation is enabled."""
+        import ray
+
+        return ray.get(self._remote_wrapper.generate.remote)
+
+    @property
+    def pad_output(self):
+        """Whether output sequences are padded."""
+        import ray
+
+        return ray.get(self._remote_wrapper.pad_output.remote)
+
+    @property
+    def text_key(self):
+        """The key for text output."""
+        import ray
+
+        return ray.get(self._remote_wrapper.text_key.remote)
+
+    @property
+    def tokens_key(self):
+        """The key for tokens output."""
+        import ray
+
+        return ray.get(self._remote_wrapper.tokens_key.remote)
+
+    @property
+    def masks_key(self):
+        """The key for masks output."""
+        import ray
+
+        return ray.get(self._remote_wrapper.masks_key.remote)
+
+    @property
+    def log_probs_key(self):
+        """The key for log probabilities output."""
+        import ray
+
+        return ray.get(self._remote_wrapper.log_probs_key.remote)
+
+    @property
+    def in_keys(self):
+        """The input keys."""
+        import ray
+
+        return ray.get(self._remote_wrapper.in_keys.remote)
+
+    @property
+    def out_keys(self):
+        """The output keys."""
+        import ray
+
+        return ray.get(self._remote_wrapper.out_keys.remote)
+
+    @property
+    def inplace(self):
+        """Whether in-place operations are used."""
+        import ray
+
+        return ray.get(self._remote_wrapper.inplace.remote)
+
+    @property
+    def device(self):
+        """The device used for computation."""
+        import ray
+
+        return ray.get(self._remote_wrapper.device.remote)
+
+    @property
+    def layout(self):
+        """The layout used for output tensors."""
+        import ray
+
+        return ray.get(self._remote_wrapper.layout.remote)
+
+    @property
+    def num_samples(self):
+        """The number of samples to generate."""
+        import ray
+
+        return ray.get(self._remote_wrapper.num_samples.remote)
+
+    @property
+    def batching(self):
+        """Whether batching is enabled."""
+        import ray
+
+        return ray.get(self._remote_wrapper.batching.remote)
+
+    @property
+    def collector(self):
+        """The collector associated with the module."""
+        import ray
+
+        return ray.get(self._remote_wrapper.collector.remote)
+
+    @property
+    def log_prob_keys(self):
+        """The keys for log probabilities."""
+        import ray
+
+        return ray.get(self._remote_wrapper.log_prob_keys.remote)
+
+    @log_prob_keys.setter
+    def log_prob_keys(self, value):
+        """Set the keys for log probabilities."""
+        import ray
+
+        ray.get(self._remote_wrapper.log_prob_keys.remote(value))
+
+    @property
+    def dist_params_keys(self):
+        """The keys for distribution parameters."""
+        import ray
+
+        return ray.get(self._remote_wrapper.dist_params_keys.remote)
+
+    @property
+    def dist_sample_keys(self):
+        """The keys for distribution samples."""
+        import ray
+
+        return ray.get(self._remote_wrapper.dist_sample_keys.remote)