torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/history.py
@@ -0,0 +1,1378 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+
9
+ import re
10
+ from typing import Literal, TYPE_CHECKING
11
+
12
+ import torch
13
+
14
+ from tensordict import (
15
+ lazy_stack,
16
+ LazyStackedTensorDict,
17
+ list_to_stack,
18
+ TensorClass,
19
+ TensorDict,
20
+ )
21
+ from tensordict.utils import _maybe_correct_neg_dim
22
+ from torchrl._utils import logger as torchrl_logger
23
+
24
+ if TYPE_CHECKING:
25
+ import transformers
26
+
27
+
28
+ # Global storage for custom templates and their metadata
29
+ _CHAT_TEMPLATES = {
30
+ "chatml_format": """{% for message in messages %}
31
+ {%- if message['role'] == 'assistant' %}
32
+ {% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
33
+ {%- else %}
34
+ {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
35
+ {%- endif %}
36
+ {% endfor %}
37
+ {%- if add_generation_prompt %}
38
+ {% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
39
+ {%- endif %}
40
+ """,
41
+ "qwen": """
42
+ {%- if tools %}
43
+ {{- '<|im_start|>system\\n' }}
44
+ {%- if messages[0]['role'] == 'system' %}
45
+ {{- messages[0]['content'] }}
46
+ {%- else %}
47
+ {{- 'You are a helpful assistant.' }}
48
+ {%- endif %}
49
+ {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
50
+ {%- for tool in tools %}
51
+ {{- "\\n" }}
52
+ {{- tool | tojson }}
53
+ {%- endfor %}
54
+ {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
55
+ {%- else %}
56
+ {%- if messages[0]['role'] == 'system' %}
57
+ {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
58
+ {%- else %}
59
+ {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
60
+ {%- endif %}
61
+ {%- endif %}
62
+ {%- for message in messages %}
63
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
64
+ {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
65
+ {%- elif (message.role == "assistant" and not message.tool_calls) %}
66
+ {% generation %} {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} {% endgeneration %}
67
+ {%- elif message.role == "assistant" %}
68
+ {% generation %}{{- '<|im_start|>' + message.role }}
69
+ {%- if message.content %}
70
+ {{- '\\n' + message.content }}
71
+ {%- endif %}
72
+ {%- for tool_call in message.tool_calls %}
73
+ {%- if tool_call.function is defined %}
74
+ {%- set tool_call = tool_call.function %}
75
+ {%- endif %}
76
+ {{- '\\n<tool_call>\\n{\\\"name\\\": \\\"' }}
77
+ {{- tool_call.name }}
78
+ {{- '\\\", \\\"arguments\\\": ' }}
79
+ {{- tool_call.arguments | tojson }}
80
+ {{- '}\\n</tool_call>' }}
81
+ {%- endfor %}
82
+ {{- '<|im_end|>\\n' }}{% endgeneration %}
83
+ {%- elif message.role == "tool" %}
84
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
85
+ {{- '<|im_start|>tool' }}
86
+ {%- endif %}
87
+ {{- '\\n<tool_response>\\n' }}
88
+ {%- if message.tool_responses %}
89
+ {{- message.tool_responses }}
90
+ {%- else %}
91
+ {{- message.content }}
92
+ {%- endif %}
93
+ {{- '\\n</tool_response>' }}
94
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
95
+ {{- '<|im_end|>\\n' }}
96
+ {%- endif %}
97
+ {%- endif %}
98
+ {%- endfor %}
99
+ {%- if add_generation_prompt %}
100
+ {% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
101
+ {%- endif %}
102
+ """,
103
+ "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ ' ' }}{% endgeneration %}{% endif %}""",
104
+ "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant: ' }}{% endgeneration %}{% endif %}""",
105
+ "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] + eos_token }}{% endgeneration %}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant:' }}{% endgeneration %}{% endif %}""",
106
+ "llama": """{{- bos_token }}
107
+ {%- if messages[0]['role'] == 'system' %}
108
+ {%- set system_message = messages[0]['content']|trim %}
109
+ {%- set messages = messages[1:] %}
110
+ {%- else %}
111
+ {%- set system_message = "" %}
112
+ {%- endif %}
113
+ {%- if system_message %}
114
+ {{- "<|header_start|>system<|header_end|>\n\n" }}
115
+ {{- system_message }}
116
+ {{- "<|eot|>" }}
117
+ {%- endif %}
118
+ {%- for message in messages %}
119
+ {%- if message['role'] == 'assistant' %}
120
+ {% generation %}{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
121
+ {%- if message['content'] is string %}
122
+ {{- message['content'] }}
123
+ {%- else %}
124
+ {%- for content in message['content'] %}
125
+ {%- if content['type'] == 'text' %}
126
+ {{- content['text'] | trim }}
127
+ {%- endif %}
128
+ {%- endfor %}
129
+ {%- endif %}
130
+ {{- "<|eot|>" }}{% endgeneration %}
131
+ {%- else %}
132
+ {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
133
+ {%- if message['content'] is string %}
134
+ {{- message['content'] }}
135
+ {%- else %}
136
+ {%- for content in message['content'] %}
137
+ {%- if content['type'] == 'text' %}
138
+ {{- content['text'] | trim }}
139
+ {%- endif %}
140
+ {%- endfor %}
141
+ {%- endif %}
142
+ {{- "<|eot|>" }}
143
+ {%- endif %}
144
+ {%- endfor %}
145
+ {%- if add_generation_prompt %}
146
+ {% generation %}{{- '<|header_start|>assistant<|header_end|>\n\n' }}{% endgeneration %}
147
+ {%- endif %}""",
148
+ }
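The templates above are consumed through the Hugging Face `apply_chat_template` API. A minimal sketch, assuming `transformers` is installed and using an illustrative Qwen checkpoint (any chat tokenizer would do), of how a template with `{% generation %}` blocks yields an assistant-token mask:

# Sketch only: shows how a template with {% generation %} blocks produces an assistant mask.
from transformers import AutoTokenizer
from torchrl.data.llm.history import _CHAT_TEMPLATES  # the (module-private) dictionary defined above

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative model
out = tok.apply_chat_template(
    messages,
    chat_template=_CHAT_TEMPLATES["chatml_format"],
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,  # requires the {% generation %} blocks above
)
# out["assistant_masks"] is 1 for tokens rendered inside {% generation %} blocks, 0 elsewhere.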
149
+
150
+ # Global storage for custom template metadata
151
+ _CUSTOM_INVERSE_PARSERS = {}
152
+ _CUSTOM_MODEL_FAMILY_KEYWORDS = {}
153
+
154
+
155
+ def add_chat_template(
156
+ template_name: str,
157
+ template: str,
158
+ inverse_parser: callable | None = None,
159
+ model_family_keywords: list[str] | None = None,
160
+ ) -> None:
161
+ r"""Add a custom chat template to the global template dictionary.
162
+
163
+ This function allows you to add custom chat templates for new model families
164
+ that support assistant token masking via the `{% generation %}` keyword.
165
+
166
+ Args:
167
+ template_name (str): The name of the template (e.g., "llama", "mistral").
168
+ This name will be used in the `chat_template_name` parameter of
169
+ `History.apply_chat_template()` and `History.from_text()`.
170
+ template (str): The Jinja2 template string. Must include `{% generation %}`
171
+ blocks around assistant message content to enable token masking.
172
+ inverse_parser (callable, optional): A function that parses formatted text back
173
+ into a History object. Should have signature `(text: str) -> History`.
174
+ If None, a basic parser will be used.
175
+ model_family_keywords (list[str], optional): Keywords to detect this model family
176
+ in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models.
177
+ If provided, the template will be automatically selected for models containing
178
+ these keywords in their name.
179
+
180
+ Example:
181
+ >>> from torchrl.data.llm.chat import add_chat_template, History
182
+ >>> from transformers import AutoTokenizer
183
+ >>>
184
+ >>> # Add a custom template for Llama models
185
+ >>> llama_template = '''
186
+ ... {% for message in messages %}
187
+ ... {%- if message['role'] == 'user' %}
188
+ ... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
189
+ ... {%- elif message['role'] == 'assistant' %}
190
+ ... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
191
+ ... {%- endif %}
192
+ ... {% endfor %}
193
+ ... {%- if add_generation_prompt %}
194
+ ... {% generation %}{{ ' ' }}{% endgeneration %}
195
+ ... {%- endif %}
196
+ ... '''
197
+ >>>
198
+ >>> def parse_llama_text(text: str) -> History:
199
+ ... # Custom parser for Llama format
200
+ ... import re
201
+ ... pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
202
+ ... matches = re.findall(pattern, text, re.DOTALL)
203
+ ... messages = []
204
+ ... for user_content, assistant_content in matches:
205
+ ... messages.append(History(role="user", content=user_content.strip()))
206
+ ... messages.append(History(role="assistant", content=assistant_content.strip()))
207
+ ... return lazy_stack(messages)
208
+ >>>
209
+ >>> # Add the template with auto-detection
210
+ >>> add_chat_template(
211
+ ... template_name="llama",
212
+ ... template=llama_template,
213
+ ... inverse_parser=parse_llama_text,
214
+ ... model_family_keywords=["llama", "meta-llama"]
215
+ ... )
216
+ >>>
217
+ >>> # Now you can use it with auto-detection
218
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
219
+ >>> history = History.from_chats([[
220
+ ... {"role": "user", "content": "Hello"},
221
+ ... {"role": "assistant", "content": "Hi there!"}
222
+ ... ]])
223
+ >>>
224
+ >>> # Auto-detection will use the llama template
225
+ >>> result = history.apply_chat_template(
226
+ ... tokenizer=tokenizer,
227
+ ... add_generation_prompt=False,
228
+ ... return_dict=True,
229
+ ... return_assistant_tokens_mask=True,
230
+ ... )
231
+ >>>
232
+ >>> # Or use it explicitly
233
+ >>> result = history.apply_chat_template(
234
+ ... tokenizer=tokenizer,
235
+ ... chat_template_name="llama",
236
+ ... add_generation_prompt=False,
237
+ ... return_dict=True,
238
+ ... return_assistant_tokens_mask=True,
239
+ ... )
240
+
241
+ .. note::
242
+ - The template must include `{% generation %}` blocks around assistant message
243
+ content to enable assistant token masking.
244
+ - The inverse parser should handle the specific format of your template.
245
+ - Model family keywords are case-insensitive and matched against the tokenizer's
246
+ `name_or_path` attribute.
247
+ - Templates are stored globally and persist for the duration of the Python session.
248
+ """
249
+ global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS # noqa: F824
250
+
251
+ # Validate template contains generation blocks
252
+ if "{% generation %}" not in template:
253
+ raise ValueError(
254
+ f"Template '{template_name}' must include '{{% generation %}}' blocks "
255
+ "around assistant message content to enable token masking."
256
+ )
257
+
258
+ # Add template to dictionary
259
+ _CHAT_TEMPLATES[template_name] = template
260
+
261
+ # Store inverse parser if provided
262
+ if inverse_parser is not None:
263
+ _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser
264
+
265
+ # Store model family keywords if provided
266
+ if model_family_keywords is not None:
267
+ _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords
268
+
269
+ torchrl_logger.info(
270
+ f"Added custom chat template '{template_name}' with assistant token masking support"
271
+ )
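For orientation, the smallest possible registration is a sketch along these lines (the template string is illustrative and the name "tiny" is hypothetical): it relies only on the mandatory `{% generation %}` block and on explicit selection by name.

# Sketch only: register a bare template and select it explicitly via chat_template_name.
from torchrl.data.llm.history import add_chat_template

add_chat_template(
    template_name="tiny",  # hypothetical name
    template=(
        "{% for message in messages %}"
        "{% if message['role'] == 'assistant' %}"
        "{% generation %}{{ message['content'] }}{% endgeneration %}"
        "{% else %}{{ message['content'] }}{% endif %}"
        "{% endfor %}"
    ),
)
# Selected later with history.apply_chat_template(..., chat_template_name="tiny").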
272
+
273
+
274
+ # We need the 'shadow' flag to prevent tensordict from complaining about 'type'/'size' etc. fields
275
+ class ContentBase(TensorClass["nocast", "shadow"]):
276
+ """Base class for all message content types.
277
+
278
+ Attributes:
279
+ type (str): The type of the content.
280
+ text (str, optional): The text content.
281
+ url (str, optional): The URL content.
282
+ data (str, optional): The data content.
283
+ mime_type (str, optional): The MIME type of the content.
284
+ name (str, optional): The name of the content.
285
+ size (int, optional): The size of the content.
286
+ function_name (str, optional): The name of the function.
287
+ function_args (dict, optional): The arguments of the function.
288
+
289
+ Examples:
290
+ >>> from tensordict import lazy_stack
291
+ >>> content1 = ContentBase(type="text", text="Hello, world!")
292
+ >>> print(content1)
293
+ ContentBase(
294
+ text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
295
+ type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
296
+ url=None,
297
+ data=None,
298
+ mime_type=None,
299
+ name=None,
300
+ size=None,
301
+ function_name=None,
302
+ function_args=None,
303
+ batch_size=torch.Size([]),
304
+ device=None,
305
+ is_shared=False)
306
+ >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
307
+ >>> print(content2)
308
+ ContentBase(
309
+ type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
310
+ url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
311
+ text=None,
312
+ data=None,
313
+ mime_type=None,
314
+ name=None,
315
+ size=None,
316
+ function_name=None,
317
+ function_args=None,
318
+ batch_size=torch.Size([]),
319
+ device=None,
320
+ is_shared=False)
321
+ >>> content = lazy_stack([content1, content2])
322
+ >>> print(content)
323
+ ContentBase(
324
+ type=NonTensorStack(
325
+ ['text', 'image'],
326
+ batch_size=torch.Size([2]),
327
+ device=None),
328
+ url=None,
329
+ data=None,
330
+ mime_type=None,
331
+ name=None,
332
+ size=None,
333
+ function_name=None,
334
+ function_args=None,
335
+ text=None,
336
+ batch_size=torch.Size([2]),
337
+ device=None,
338
+ is_shared=False)
339
+ >>> # A content is typically used in a History object. Usually, its batch dimension is
340
+ >>> # one dimension greater than the History object.
341
+ >>> history = History(role="user", content=content)
342
+
343
+ """
344
+
345
+ type: Literal[
346
+ "text", "image", "audio", "video", "file", "function_call"
347
+ ] # Required: "text", "image", "audio", "video", "file", "function_call"
348
+
349
+ # Text content
350
+ text: str | None = None
351
+
352
+ # Media/file content (either URL or data)
353
+ url: str | None = None # HTTP URL to content
354
+ data: str | None = None # Base64 encoded content
355
+
356
+ # Metadata
357
+ mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf"
358
+ name: str | None = None # Original filename or description
359
+ size: int | None = None # File size in bytes
360
+
361
+ # Function calling (for AI agents)
362
+ function_name: str | None = None
363
+ function_args: dict | None = None
364
+
365
+
366
+ class History(TensorClass["nocast"]):
367
+ """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
368
+
369
+ The `History` class provides a centralized API for managing conversational data, offering several advantages over
370
+ traditional list-based approaches:
371
+
372
+ - Centralized API for conversion to and from string formats, facilitating seamless integration with language models.
373
+ - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation
374
+ trajectories, especially useful in reinforcement learning environments.
375
+ - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data.
376
+ - **Assistant token masking support** across multiple model families for reinforcement learning applications.
377
+
378
+ **Recent Changes:**
379
+ - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory`
380
+ containers for structured conversation management in LLM environments.
381
+ - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"`
382
+ is specified, providing consistent conversation state management.
383
+ - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking.
384
+
385
+ .. note:: The `"<none>"` role is used to indicate that the element is a placeholder,
386
+ for example when the tool call was not executed but a stack requires a certain number of elements
387
+ per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template`
388
+ method will remove the `<none>` role from the history.
389
+
390
+ **Assistant Token Masking Support:**
391
+
392
+ The class supports assistant token masking across multiple model families, allowing you to identify which tokens
393
+ in a conversation were generated by the assistant. This is crucial for reinforcement learning applications.
394
+
395
+ **Supported Model Families:**
396
+
397
+ - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
398
+ - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
399
+ - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
400
+ - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
401
+ - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template
402
+
403
+ **Example with Assistant Token Masking:**
404
+
405
+ .. code-block:: python
406
+
407
+ >>> from torchrl.data.llm.chat import History
408
+ >>> from torchrl.modules.llm.policies import ChatHistory
409
+ >>> from transformers import AutoTokenizer
410
+ >>>
411
+ >>> # Create a conversation history
412
+ >>> history = History.from_chats([[
413
+ ... {"role": "user", "content": "Hello"},
414
+ ... {"role": "assistant", "content": "Hi there!"},
415
+ ... {"role": "user", "content": "How are you?"},
416
+ ... {"role": "assistant", "content": "I'm doing well, thanks!"}
417
+ ... ]])
418
+ >>>
419
+ >>> # Create ChatHistory container for LLM wrapper
420
+ >>> chat_history = ChatHistory(prompt=history)
421
+ >>>
422
+ >>> # Load any supported tokenizer
423
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
424
+ >>>
425
+ >>> # Apply chat template with assistant token masking
426
+ >>> result = history.apply_chat_template(
427
+ ... tokenizer=tokenizer,
428
+ ... add_generation_prompt=False,
429
+ ... return_dict=True,
430
+ ... return_assistant_tokens_mask=True,
431
+ ... )
432
+ >>>
433
+ >>> # The result contains an assistant_masks tensor
434
+ >>> assistant_masks = result["assistant_masks"]
435
+ >>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
436
+
437
+ **Integration with LLM Wrappers:**
438
+
439
+ History objects work seamlessly with the new modular wrapper design:
440
+
441
+ .. code-block:: python
442
+
443
+ >>> from torchrl.modules.llm import TransformersWrapper
444
+ >>> from torchrl.modules.llm.policies import ChatHistory
445
+ >>>
446
+ >>> # Create wrapper with history input mode
447
+ >>> wrapper = TransformersWrapper(
448
+ ... model, tokenizer=tokenizer,
449
+ ... input_mode="history",
450
+ ... generate=True,
451
+ ... return_log_probs=True
452
+ ... )
453
+ >>>
454
+ >>> # Use History with ChatHistory container
455
+ >>> history = History.from_chats([[
456
+ ... {"role": "user", "content": "Hello"},
457
+ ... {"role": "assistant", "content": "Hi there!"}
458
+ ... ]])
459
+ >>> chat_history = ChatHistory(prompt=history)
460
+ >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
461
+ >>> print(result["history"].response) # New response from LLM
462
+
463
+ Attributes:
464
+ role (str): The role of the message sender.
465
+ content (str): The content of the message.
466
+ is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`.
467
+ tool_calls (list[dict] | None): Optional list of tool calls in the message.
468
+ tool_responses (list[str] | None): Optional list of tool responses.
469
+
470
+ Methods:
471
+ apply_chat_template: converts the `History` object to str / tokens.
472
+ append: append one element to the list of items along a given dimension.
473
+ extend: extend the list of items along a given dimension.
474
+
475
+ Examples:
476
+ >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches
477
+ >>> import tensordict
478
+ >>> tensordict.set_list_to_stack(True).set()
479
+ >>> import transformers
480
+ >>> history0 = History(
481
+ ... role='system',
482
+ ... content='''CONTENT
483
+ ... This is the setup''',
484
+ ... )
485
+ >>> history1 = History(
486
+ ... role='user',
487
+ ... content='''CONTENT
488
+ ... This is the first user prompt''',
489
+ ... )
490
+ >>> history2 = History(
491
+ ... role='assistant',
492
+ ... content='''CONTENT
493
+ ... This is the second prompt, the first for the assistant.''',
494
+ ... )
495
+ >>> history = torch.stack([history0, history1, history2])
496
+ >>> assert history.role == ['system', 'user', 'assistant']
497
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2")
498
+ >>> # Apply a template to pass the history to an LLM. Note that the output has
499
+ >>> # an additional prompt to elicit an answer from the LLM thanks to the 'add_generation_prompt' argument.
500
+ >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
501
+ >>> parsed_string
502
+ <|im_start|>system
503
+ CONTENT
504
+ This is the setup<|im_end|>
505
+
506
+ <|im_start|>user
507
+ CONTENT
508
+ This is the first user prompt<|im_end|>
509
+
510
+ <|im_start|>assistant
511
+ CONTENT
512
+ This is the second prompt, the first for the assistant.<|im_end|>
513
+
514
+ <|im_start|>assistant
515
+
516
+ .. seealso::
517
+ :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments.
518
+ :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
519
+ :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
520
+ """
521
+
522
+ role: str | list[str] | list[list[str]]
523
+ content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
524
+ list[ContentBase]
525
+ ]
526
+ is_complete: bool = True
527
+ tool_calls: list[dict] | None = None
528
+ tool_responses: list[str] | None = None
529
+
530
+ def __post_init__(self):
531
+ if not list_to_stack():
532
+ raise RuntimeError(
533
+ "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, "
534
+ "or the LIST_TO_STACK=1 environment variable."
535
+ )
536
+
537
+ def apply_chat_template(
538
+ self,
539
+ *,
540
+ tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
541
+ add_generation_prompt: bool = True,
542
+ chat_template: str | None = None,
543
+ chat_template_name: str | None = None,
544
+ continue_final_message: bool = False,
545
+ tokenize: bool | None = None,
546
+ padding: bool | str = False,
547
+ truncation: bool | str = False,
548
+ return_tensors: str | None = None,
549
+ return_dict: bool | None = None,
550
+ return_assistant_tokens_mask: bool = False,
551
+ **kwargs,
552
+ ) -> str | list[str] | TensorDict:
553
+ """Applies a chat template to the history.
554
+
555
+ Keyword Args:
556
+ tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
557
+ add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
558
+ chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
559
+ chat_template_name (str, optional): The name of the chat template to use.
560
+ Takes precedence over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template.
561
+ Defaults to `None`.
562
+ continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
563
+ tokenize (bool, optional): Whether to tokenize the output. If `None` (default), tokenization is enabled when `return_assistant_tokens_mask` is set or `return_tensors` is provided, and disabled otherwise.
564
+ padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
565
+ truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
566
+ return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt" when the output is tokenized.
567
+ return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
568
+ return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
569
+ If `True`, the mask will be written to the `assistant_masks` key.
570
+ For tokens generated by the assistant, the mask will contain `1`.
571
+ For user and system tokens, the mask will contain `0`.
572
+ This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
573
+ Defaults to `False`.
574
+
575
+ .. note:: Assistant token masking is supported across multiple model families:
576
+ - **Qwen family**: Uses custom template with full tool calling support
577
+ - **DialoGPT family**: Uses custom template for conversation format
578
+ - **Falcon family**: Uses custom template for instruction format
579
+ - **DeepSeek family**: Uses custom template with native format
580
+ - **Other models**: Use the default `chatml_format` template
581
+
582
+ The method automatically detects the model family and selects the appropriate template.
583
+
584
+ **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
585
+
586
+ Returns:
587
+ The formatted history.
588
+ """
589
+ if chat_template is None:
590
+ if chat_template_name is not None:
591
+ chat_template = _CHAT_TEMPLATES[chat_template_name]
592
+ chat_template_name = None
593
+ elif tokenizer is None:
594
+ raise RuntimeError(
595
+ "You must specify a tokenizer to use when chat_template is not specified."
596
+ )
597
+ else:
598
+ # Auto-detect model family and use appropriate template
599
+ model_name = getattr(tokenizer, "name_or_path", "").lower()
600
+
601
+ # First check for custom model family keywords
602
+ custom_template_found = False
603
+ for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
604
+ if any(keyword.lower() in model_name for keyword in keywords):
605
+ chat_template = _CHAT_TEMPLATES[template_name]
606
+ chat_template_name = None
607
+ custom_template_found = True
608
+ break
609
+
610
+ if not custom_template_found:
611
+ # Fall back to built-in model family detection
612
+ if "qwen" in model_name:
613
+ # We prefer our implementation of the Qwen template,
614
+ # since it accounts for the assistant's masking.
615
+ chat_template = _CHAT_TEMPLATES["qwen"]
616
+ chat_template_name = None
617
+ elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
618
+ # DialoGPT family - use our custom template
619
+ chat_template = _CHAT_TEMPLATES["dialogpt"]
620
+ chat_template_name = None
621
+ elif "falcon" in model_name or "tiiuae/falcon" in model_name:
622
+ # Falcon family - use our custom template
623
+ chat_template = _CHAT_TEMPLATES["falcon"]
624
+ chat_template_name = None
625
+ elif "deepseek" in model_name:
626
+ # DeepSeek family - use our custom template with generation keyword
627
+ chat_template = _CHAT_TEMPLATES["deepseek"]
628
+ chat_template_name = None
629
+ elif "llama" in model_name:
630
+ # Llama family - use our custom template
631
+ chat_template = _CHAT_TEMPLATES["llama"]
632
+ chat_template_name = None
633
+ else:
634
+ # For other models, check if their default template supports generation
635
+ if (
636
+ hasattr(tokenizer, "chat_template")
637
+ and tokenizer.chat_template
638
+ and "{% generation %}" in tokenizer.chat_template
639
+ ):
640
+ # Use the model's own template if it supports generation
641
+ chat_template = tokenizer.chat_template
642
+ else:
643
+ # Use our default chatml_format template
644
+ chat_template = _CHAT_TEMPLATES["chatml_format"]
645
+ if chat_template is None:
646
+ chat_template = _CHAT_TEMPLATES["chatml_format"]
647
+ if tokenize is None:
648
+ if return_assistant_tokens_mask or return_tensors is not None:
649
+ tokenize = True
650
+ else:
651
+ tokenize = False
652
+ if tokenize:
653
+ if return_tensors is None:
654
+ return_tensors = "pt"
655
+ if return_dict is None and return_assistant_tokens_mask:
656
+ return_dict = True
657
+ elif return_dict is None:
658
+ return_dict = False
659
+
660
+ if self.ndim > 1:
661
+ result = [
662
+ self[i].apply_chat_template(
663
+ tokenizer=tokenizer,
664
+ add_generation_prompt=add_generation_prompt,
665
+ chat_template=chat_template,
666
+ chat_template_name=chat_template_name,
667
+ tokenize=tokenize,
668
+ padding=padding,
669
+ truncation=truncation,
670
+ return_tensors=return_tensors,
671
+ continue_final_message=continue_final_message,
672
+ return_dict=return_dict,
673
+ return_assistant_tokens_mask=return_assistant_tokens_mask,
674
+ **kwargs,
675
+ )
676
+ for i in range(self.batch_size[0])
677
+ ]
678
+ if return_dict:
679
+ return lazy_stack(result)
680
+ else:
681
+ return result
682
+ self_flat = self.view(-1)
683
+ # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
684
+ self_flat = self_flat.tolist(tolist_first=True)
685
+ # Remove the "<none>" role
686
+ self_flat = [item for item in self_flat if item["role"] != "<none>"]
687
+ result = tokenizer.apply_chat_template(
688
+ conversation=self_flat,
689
+ add_generation_prompt=add_generation_prompt,
690
+ chat_template=chat_template,
691
+ tokenize=tokenize,
692
+ padding=padding,
693
+ truncation=truncation,
694
+ return_tensors=return_tensors,
695
+ continue_final_message=continue_final_message,
696
+ return_dict=return_dict,
697
+ return_assistant_tokens_mask=return_assistant_tokens_mask,
698
+ **kwargs,
699
+ )
700
+ if not isinstance(result, (torch.Tensor, list, str)):
701
+ result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
702
+ # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
703
+ if self.batch_dims == 1:
704
+ if result.batch_size[0] != 1:
705
+ raise RuntimeError(
706
+ f"Expected a batch size of 1, got {result.batch_size[0]}."
707
+ )
708
+ result = result.squeeze(0)
709
+ return result
710
+
711
+ @classmethod
712
+ def from_text(
713
+ cls,
714
+ text: str | list[str],
715
+ chat_template_name: str | None = None,
716
+ # currently without effect
717
+ chat_template: str | None = None,
718
+ tokenizer: transformers.AutoTokenizer # noqa: F821
719
+ | transformers.AutoProcessor # noqa: F821
720
+ | None = None,
721
+ ) -> History:
722
+ r"""Inverts a chat template into a History object.
723
+
724
+ Args:
725
+ text (str | list[str]): The chat template to invert.
726
+ chat_template_name (str, optional): The name of the chat template to use.
727
+ tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
728
+
729
+ Returns:
730
+ History: The inverted History object.
731
+
732
+ Examples:
733
+ >>> from torchrl.data.llm.history import History
734
+ >>> from transformers import AutoTokenizer
735
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
736
+ >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
737
+ >>> history = History.from_text(text, tokenizer=tokenizer)
738
+ >>> print(history)
739
+ History(
740
+ content=NonTensorStack(
741
+ ['You are a helpful assistant.', 'Write a python s...,
742
+ batch_size=torch.Size([3]),
743
+ device=None),
744
+ is_complete=NonTensorStack(
745
+ [True, True, False],
746
+ batch_size=torch.Size([3]),
747
+ device=None),
748
+ role=NonTensorStack(
749
+ ['system', 'user', 'assistant'],
750
+ batch_size=torch.Size([3]),
751
+ device=None),
752
+ tool_calls=None,
753
+ tool_responses=None,
754
+ batch_size=torch.Size([3]),
755
+ device=None,
756
+ is_shared=False)
757
+ """
758
+ if chat_template_name is None:
759
+ if chat_template is not None:
760
+ # TODO: find best match given template
761
+ pass
762
+
763
+ model_name = getattr(tokenizer, "name_or_path", "").lower()
764
+ # First check for custom model family keywords
765
+ custom_template_found = False
766
+ for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
767
+ if any(keyword.lower() in model_name for keyword in keywords):
768
+ chat_template_name = template_name
769
+ custom_template_found = True
770
+ break
771
+
772
+ if not custom_template_found:
773
+ # Fall back to built-in model family detection
774
+ if "qwen" in model_name:
775
+ # We can automatically detect the template name from the tokenizer
776
+ # and use the precoded parser.
777
+ chat_template_name = "qwen"
778
+ elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
779
+ chat_template_name = "dialogpt"
780
+ elif "falcon" in model_name or "tiiuae/falcon" in model_name:
781
+ chat_template_name = "falcon"
782
+ elif "deepseek" in model_name:
783
+ chat_template_name = "deepseek"
784
+ elif "llama" in model_name:
785
+ chat_template_name = "llama"
786
+ else:
787
+ chat_template_name = "chatml_format"
788
+
789
+ # Get the appropriate inverse parser function
790
+ if chat_template_name in ("chatml_format",):
791
+ func = cls._inv_chatml
792
+ elif chat_template_name in ("qwen",):
793
+ func = cls._inv_qwen
794
+ elif chat_template_name in ("dialogpt",):
795
+ func = cls._inv_dialogpt
796
+ elif chat_template_name in ("falcon",):
797
+ func = cls._inv_falcon
798
+ elif chat_template_name in ("deepseek",):
799
+ func = cls._inv_deepseek
800
+ elif chat_template_name in ("llama",):
801
+ func = cls._inv_llama
802
+ elif chat_template_name in _CUSTOM_INVERSE_PARSERS:
803
+ # Use custom inverse parser
804
+ func = _CUSTOM_INVERSE_PARSERS[chat_template_name]
805
+ else:
806
+ raise NotImplementedError(
807
+ f"chat_template_name '{chat_template_name}' is not supported. "
808
+ "Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. "
809
+ "Use add_chat_template() to add custom templates."
810
+ )
811
+ if isinstance(text, list):
812
+ list_of_histories = [func(t) for t in text]
813
+ try:
814
+ return lazy_stack(list_of_histories)
815
+ except RuntimeError as e:
816
+ raise RuntimeError(
817
+ f"Failed to stack histories: {list_of_histories=}"
818
+ ) from e
819
+ return func(text)
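As a brief illustration of the inverse parsing described above, a minimal round-trip sketch (assuming `transformers` is installed; the checkpoint name mirrors the docstrings and is illustrative):

# Sketch only: format a History with a named template, then invert it with from_text.
import tensordict
import torch
import transformers

tensordict.set_list_to_stack(True).set()  # required by History.__post_init__ (see above)
from torchrl.data.llm.history import History

tok = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative checkpoint
history = torch.stack([
    History(role="user", content="Hello"),
    History(role="assistant", content="Hi there!"),
])
text = history.apply_chat_template(
    tokenizer=tok, chat_template_name="chatml_format",
    add_generation_prompt=False, tokenize=False,
)
recovered = History.from_text(text, chat_template_name="chatml_format")
# Roles and contents survive the round trip:
assert recovered.role == ["user", "assistant"]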
820
+
821
+ @classmethod
822
+ def _inv_chatml(cls, text: str) -> History:
823
+ """Inverts a chatml string into a History object.
824
+
825
+ Args:
826
+ text (str): The chatml string to invert.
827
+
828
+ Returns:
829
+ History: The inverted History object.
830
+ """
831
+ import json
832
+
833
+ torchrl_logger.debug(f"Inverting chatml:\n{text}")
834
+ # Find all complete blocks (ending with im_end or endoftext)
835
+ complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>"
836
+ complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL)
837
+
838
+ # Find any incomplete block at the end
839
+ incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$"
840
+ incomplete_matches = []
841
+ if complete_matches:
842
+ # Look for incomplete block after the last complete one
843
+ last_complete = complete_matches[-1]
844
+ last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>"
845
+ remaining_text = text[
846
+ text.rindex(last_complete_text) + len(last_complete_text) :
847
+ ]
848
+ if remaining_text.strip():
849
+ incomplete_match = re.search(
850
+ incomplete_pattern, remaining_text, flags=re.DOTALL
851
+ )
852
+ if incomplete_match:
853
+ incomplete_matches = [
854
+ (incomplete_match.group(1), incomplete_match.group(2), None)
855
+ ]
856
+ else:
857
+ # No complete blocks, check entire text for incomplete block
858
+ incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL)
859
+ if incomplete_match:
860
+ incomplete_matches = [
861
+ (incomplete_match.group(1), incomplete_match.group(2), None)
862
+ ]
863
+
864
+ # Combine complete and incomplete matches
865
+ matches = complete_matches + incomplete_matches
866
+
867
+ # Define tool patterns - same as Qwen for consistency
868
+ tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
869
+ tool_response_pattern = re.compile(
870
+ r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
871
+ )
872
+
873
+ parsed_messages = []
874
+ for match in matches:
875
+ role = match[0].strip()
876
+ content = match[1].strip()
877
+ is_complete = match[2] is not None # None indicates incomplete
878
+
879
+ # Initialize message dict
880
+ message_dict = {
881
+ "role": role,
882
+ "content": content,
883
+ "is_complete": is_complete,
884
+ "tool_calls": None,
885
+ "tool_responses": None,
886
+ }
887
+
888
+ # Find tool calls within the message
889
+ tool_calls = tool_call_pattern.findall(content)
890
+ if tool_calls:
891
+ tool_calls_list = []
892
+ for tool_call in tool_calls:
893
+ try:
894
+ tool_call_dict = json.loads(tool_call)
895
+ tool_calls_list.append(tool_call_dict)
896
+ except json.JSONDecodeError:
897
+ continue
898
+ if tool_calls_list:
899
+ message_dict["tool_calls"] = tool_calls_list
900
+
901
+ # Check for tool responses
902
+ tool_responses = tool_response_pattern.findall(content)
903
+ if tool_responses:
904
+ message_dict["tool_responses"] = tool_responses
905
+
906
+ parsed_messages.append(cls(**message_dict))
907
+
908
+ if not parsed_messages:
909
+ raise RuntimeError(
910
+ f"Couldn't get a single item out of text {text}. A common cause "
911
+ f"if that special tokens should not be ommitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
912
+ )
913
+
914
+ return lazy_stack(parsed_messages)
915
+
916
+ @classmethod
917
+ def _inv_qwen(cls, template: str) -> History:
918
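+ """Inverts a Qwen chat-template string into a History object.
+
+ Args:
+ template (str): The Qwen-formatted string to invert.
+
+ Returns:
+ History: The inverted History object.
+ """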
+ import json
919
+
920
+ # Define regex patterns for different parts of the template
921
+ message_pattern = re.compile(
922
+ r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL
923
+ )
924
+ tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
925
+ tool_response_pattern = re.compile(
926
+ r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
927
+ )
928
+
929
+ # Find all messages and track if they end with a proper token
930
+ messages = []
931
+ is_complete_list = []
932
+ for match in message_pattern.finditer(template):
933
+ full_match = match.group(0)
934
+ messages.append(match.group(1))
935
+ # Check if the message ends with a proper token
936
+ is_complete_list.append(
937
+ full_match.endswith("<|im_end|>")
938
+ or full_match.endswith("<|endoftext|>")
939
+ )
940
+
941
+ parsed_messages = []
942
+ for message, is_complete in zip(messages, is_complete_list):
943
+ # Split the message into role and content
944
+ parts = message.split("\n", 1)
945
+ if len(parts) < 2:
946
+ continue
947
+ role, content = parts[0], parts[1]
948
+
949
+ # Initialize message dict
950
+ message_dict = {
951
+ "role": role.strip(),
952
+ "content": content.strip(),
953
+ "is_complete": is_complete,
954
+ "tool_calls": None,
955
+ "tool_responses": None,
956
+ }
957
+
958
+ # Find tool calls within the message
959
+ tool_calls = tool_call_pattern.findall(content)
960
+ if tool_calls:
961
+ tool_calls_list = []
962
+ for tool_call in tool_calls:
963
+ try:
964
+ tool_call_dict = json.loads(tool_call)
965
+ tool_calls_list.append(tool_call_dict)
966
+ except json.JSONDecodeError:
967
+ continue
968
+ if tool_calls_list:
969
+ message_dict["tool_calls"] = tool_calls_list
970
+
971
+ # Check for tool responses
972
+ tool_responses = tool_response_pattern.findall(content)
973
+ if tool_responses:
974
+ message_dict["tool_responses"] = tool_responses
975
+
976
+ parsed_messages.append(cls(**message_dict))
977
+
978
+ if not parsed_messages:
979
+ raise RuntimeError(
980
+ f"Couldn't get a single item out of text {template}. A common cause "
981
+ f"if that special tokens should not be ommitted, did you set include_stop_str_in_output/skip_special_tokens=False?"
982
+ )
983
+
984
+ return lazy_stack(parsed_messages)
985
+
986
+ @classmethod
987
+ def _inv_dialogpt(cls, text: str) -> History:
988
+ """Inverts a DialogPT string into a History object.
989
+
990
+ Args:
991
+ text (str): The DialogPT string to invert.
992
+
993
+ Returns:
994
+ History: The inverted History object.
995
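+
+ The expected layout is one message per line, optionally prefixed with
+ ``User:`` or ``Assistant:``; lines without a prefix default to the user role.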
+ """
996
+ torchrl_logger.debug(f"Inverting DialogPT:\n{text}")
997
+
998
+ # DialogPT format is simple: alternating user/assistant messages
999
+ # Split by lines and parse
1000
+ lines = text.strip().split("\n")
1001
+ parsed_messages = []
1002
+
1003
+ for line in lines:
1004
+ line = line.strip()
1005
+ if not line:
1006
+ continue
1007
+
1008
+ # Determine role based on content
1009
+ if line.startswith("Assistant:"):
1010
+ role = "assistant"
1011
+ content = line[len("Assistant:") :].strip()
1012
+ elif line.startswith("User:"):
1013
+ role = "user"
1014
+ content = line[len("User:") :].strip()
1015
+ else:
1016
+ # Default to user if no prefix
1017
+ role = "user"
1018
+ content = line
1019
+
1020
+ message_dict = {
1021
+ "role": role,
1022
+ "content": content,
1023
+ "is_complete": True, # DialogPT doesn't have explicit end tokens
1024
+ "tool_calls": None,
1025
+ "tool_responses": None,
1026
+ }
1027
+
1028
+ parsed_messages.append(cls(**message_dict))
1029
+
1030
+ if not parsed_messages:
1031
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
1032
+
1033
+ return lazy_stack(parsed_messages)
1034
+
1035
+ @classmethod
1036
+ def _inv_falcon(cls, text: str) -> History:
1037
+ """Inverts a Falcon string into a History object.
1038
+
1039
+ Args:
1040
+ text (str): The Falcon string to invert.
1041
+
1042
+ Returns:
1043
+ History: The inverted History object.
1044
+ """
1045
+ torchrl_logger.debug(f"Inverting Falcon:\n{text}")
1046
+
1047
+ # Falcon format: "User: ... Assistant: ..."
1048
+ # Split by "User:" and "Assistant:" prefixes
1049
+ import re
1050
+
1051
+ # Pattern to match User: and Assistant: messages
1052
+ pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
1053
+ matches = re.findall(pattern, text, re.DOTALL)
1054
+
1055
+ parsed_messages = []
1056
+ for match in matches:
1057
+ if len(match) != 2:
1058
+ continue
1059
+ prefix, content = match
1060
+ content = content.strip()
1061
+ if not content:
1062
+ continue
1063
+
1064
+ if prefix == "User:":
1065
+ role = "user"
1066
+ elif prefix == "Assistant:":
1067
+ role = "assistant"
1068
+ else:
1069
+ continue
1070
+
1071
+ message_dict = {
1072
+ "role": role,
1073
+ "content": content,
1074
+ "is_complete": True, # Falcon doesn't have explicit end tokens
1075
+ "tool_calls": None,
1076
+ "tool_responses": None,
1077
+ }
1078
+
1079
+ parsed_messages.append(cls(**message_dict))
1080
+
1081
+ if not parsed_messages:
1082
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
1083
+
1084
+ return lazy_stack(parsed_messages)
1085
+
1086
+ @classmethod
1087
+ def _inv_deepseek(cls, text: str) -> History:
1088
+ """Inverts a DeepSeek string into a History object.
1089
+
1090
+ Args:
1091
+ text (str): The DeepSeek string to invert.
1092
+
1093
+ Returns:
1094
+ History: The inverted History object.
1095
+ """
1096
+ torchrl_logger.debug(f"Inverting DeepSeek:\n{text}")
1097
+ import re
1098
+
1099
+ # Remove leading/trailing special tokens (e.g. "<...>"-style markers)
1100
+ text = re.sub(r"^<[^>]+>", "", text) # Remove leading <...>
1101
+ text = re.sub(r"<[^>]+>$", "", text) # Remove trailing <...>
1102
+ # Remove any REDACTED_SPECIAL_TOKEN if present
1103
+ text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text)
1104
+ # Pattern to match User: and Assistant: messages
1105
+ pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
1106
+ matches = re.findall(pattern, text, re.DOTALL)
1107
+ parsed_messages = []
1108
+ for match in matches:
1109
+ if len(match) < 2:
1110
+ continue
1111
+ prefix, content = match[0], match[1]
1112
+ content = content.strip()
1113
+ if not content:
1114
+ continue
1115
+ if prefix == "User:":
1116
+ role = "user"
1117
+ elif prefix == "Assistant:":
1118
+ role = "assistant"
1119
+ else:
1120
+ continue
1121
+ message_dict = {
1122
+ "role": role,
1123
+ "content": content,
1124
+ "is_complete": True, # DeepSeek doesn't have explicit end tokens
1125
+ "tool_calls": None,
1126
+ "tool_responses": None,
1127
+ }
1128
+ parsed_messages.append(cls(**message_dict))
1129
+ if not parsed_messages:
1130
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
1131
+ return lazy_stack(parsed_messages)
1132
+
1133
+ @classmethod
1134
+ def _inv_llama(cls, text: str) -> History:
1135
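+ """Inverts a Llama-formatted string into a History object.
+
+ Messages are expected as ``<|header_start|>role<|header_end|>`` blocks
+ terminated by ``<|eot|>``; a trailing block without ``<|eot|>`` is parsed
+ as an incomplete message.
+
+ Args:
+ text (str): The Llama-formatted string to invert.
+
+ Returns:
+ History: The inverted History object.
+ """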
+ import re
1136
+
1137
+ messages = []
1138
+
1139
+ # Remove BOS token if present
1140
+ if text.startswith("<|begin_of_text|>"):
1141
+ text = text[len("<|begin_of_text|>") :]
1142
+
1143
+ # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|>
1144
+ complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>"
1145
+ complete_matches = re.findall(complete_pattern, text, re.DOTALL)
1146
+
1147
+ # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>)
1148
+ incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$"
1149
+
1150
+ # Find any incomplete message at the end
1151
+ incomplete_matches = []
1152
+ if complete_matches:
1153
+ # Look for incomplete message after the last complete one
1154
+ last_complete_end = text.rfind("<|eot|>")
1155
+ if last_complete_end != -1:
1156
+ remaining_text = text[last_complete_end + len("<|eot|>") :]
1157
+ if remaining_text.strip():
1158
+ incomplete_match = re.search(
1159
+ incomplete_pattern, remaining_text, re.DOTALL
1160
+ )
1161
+ if incomplete_match:
1162
+ incomplete_matches = [
1163
+ (
1164
+ incomplete_match.group(1),
1165
+ incomplete_match.group(2),
1166
+ False,
1167
+ )
1168
+ ]
1169
+ else:
1170
+ # No complete messages, check entire text for incomplete message
1171
+ incomplete_match = re.search(incomplete_pattern, text, re.DOTALL)
1172
+ if incomplete_match:
1173
+ incomplete_matches = [
1174
+ (incomplete_match.group(1), incomplete_match.group(2), False)
1175
+ ]
1176
+
1177
+ # Process complete messages
1178
+ for role, content in complete_matches:
1179
+ if content.strip():
1180
+ messages.append(
1181
+ cls(role=role, content=content.strip(), is_complete=True)
1182
+ )
1183
+
1184
+ # Process incomplete messages
1185
+ for role, content, is_complete in incomplete_matches:
1186
+ if content.strip():
1187
+ messages.append(
1188
+ cls(role=role, content=content.strip(), is_complete=is_complete)
1189
+ )
1190
+
1191
+ if not messages:
1192
+ raise RuntimeError(f"Couldn't parse Llama format from text: {text}")
1193
+
1194
+ from tensordict import lazy_stack
1195
+
1196
+ return lazy_stack(messages)
1197
+
1198
+ def append(
1199
+ self, history: History, *, inplace: bool = True, dim: int = -1
1200
+ ) -> History:
1201
+ """Appends a new history to the current one.
1202
+
1203
+ Args:
1204
+ history (History): The new history to append.
1205
+ inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
1206
+ dim (int, optional): The dimension to append along. Defaults to -1.
1207
+
1208
+ Returns:
1209
+ History: The appended History object.
1210
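+
+ Example:
+ >>> # Illustrative sketch with made-up messages; assumes
+ >>> # tensordict.set_list_to_stack(True).set() as in the other examples.
+ >>> history = History(role=["user"], content=["Hello"], batch_size=(1,))
+ >>> reply = History(role="assistant", content="Hi there!", batch_size=())  # batchless item
+ >>> history = history.append(reply, inplace=False)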
+ """
1211
+ # TODO: we should remove the <none> role from the history before appending / extending
1212
+ # It works when keeping them, but it may lead to a lot of useless padding in between valid messages
1213
+ if not self.batch_dims:
1214
+ raise RuntimeError(
1215
+ "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self."
1216
+ )
1217
+ if self.batch_dims != history.batch_dims + 1:
1218
+ raise RuntimeError(
1219
+ f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
1220
+ )
1221
+ dim = _maybe_correct_neg_dim(dim, self.batch_size)
1222
+ if inplace:
1223
+ if (
1224
+ isinstance(self._tensordict, LazyStackedTensorDict)
1225
+ and self._tensordict.stack_dim == dim
1226
+ ):
1227
+ td = history._tensordict
1228
+ if td.device != self.device:
1229
+ if self.device is None:
1230
+ td = td.copy().clear_device_()
1231
+ else:
1232
+ td = td.to(self.device)
1233
+ self._tensordict.append(td)
1234
+ return self
1235
+ else:
1236
+ td = history._tensordict
1237
+ if td.device != self.device:
1238
+ if self.device is None:
1239
+ td = td.copy().clear_device_()
1240
+ else:
1241
+ td = td.to(self.device)
1242
+ td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim)
1243
+ self.__dict__["_tensordict"] = td
1244
+ return self
1245
+ if history.device != self.device:
1246
+ if self.device is None:
1247
+ history = history.copy().clear_device_()
1248
+ else:
1249
+ history = history.to(self.device)
1250
+ return lazy_stack(list(self.unbind(dim)) + [history], dim=dim)
1251
+
1252
+ def extend(
1253
+ self, history: History, *, inplace: bool = True, dim: int = 0
1254
+ ) -> History:
1255
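+ """Extends the current history with the items of another one along ``dim``.
+
+ Args:
+ history (History): The new history to extend with. It must have the same
+ number of dimensions as `self`.
+ inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
+ dim (int, optional): The dimension to extend along. Defaults to 0.
+
+ Returns:
+ History: The extended History object.
+ """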
+ if not self.batch_dims:
1256
+ raise RuntimeError(
1257
+ "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self."
1258
+ )
1259
+ if self.batch_dims != history.batch_dims:
1260
+ raise RuntimeError(
1261
+ f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={self.ndim}."
1262
+ )
1263
+ dim = _maybe_correct_neg_dim(dim, self.batch_size)
1264
+ # if self.ndim > 1 and dim >= self.ndim - 1:
1265
+ # # then we need to append each element independently
1266
+ # result = []
1267
+ # for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
1268
+ # hist_c = hist.extend(new_hist, inplace=inplace, dim=dim - 1)
1269
+ # result.append(hist_c)
1270
+ # if inplace:
1271
+ # return self
1272
+ # return lazy_stack(result)
1273
+ if inplace:
1274
+ if (
1275
+ isinstance(self._tensordict, LazyStackedTensorDict)
1276
+ and self._tensordict.stack_dim == dim
1277
+ ):
1278
+ td = history._tensordict
1279
+ if td.device != self.device:
1280
+ if self.device is None:
1281
+ td = td.copy().clear_device_()
1282
+ else:
1283
+ td = td.to(self.device)
1284
+ self._tensordict.extend(td)
1285
+ return self
1286
+ else:
1287
+ td = lazy_stack(
1288
+ list(self._tensordict.unbind(dim))
1289
+ + list(history._tensordict.unbind(dim)),
1290
+ dim=dim,
1291
+ )
1292
+ if td.device != self.device:
1293
+ if self.device is None:
1294
+ td = td.copy().clear_device_()
1295
+ else:
1296
+ td = td.to(self.device)
1297
+ self.__dict__["_tensordict"] = td
1298
+ return self
1299
+ if history.device != self.device:
1300
+ if self.device is None:
1301
+ history = history.copy().clear_device_()
1302
+ else:
1303
+ history = history.to(self.device)
1304
+ return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim)
1305
+
1306
+ @classmethod
1307
+ def default_spec(cls, shape=(-1,)):
1308
+ """A default spec to use in transforms / envs that return History objects.
1309
+
1310
+ Args:
1311
+ shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1,)` (variable length
1312
+ along the time dimension).
1313
+
1314
+ Example:
1315
+ >>> import tensordict
1316
+ >>> from torchrl.data import History
1317
+ >>> tensordict.set_list_to_stack(True).set()
1318
+ >>>
1319
+ >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,))
1320
+ >>> spec = history.default_spec()
1321
+ >>> print(spec)
1322
+ Composite(
1323
+ role: NonTensor(
1324
+ shape=torch.Size([-1]),
1325
+ space=None,
1326
+ device=None,
1327
+ dtype=None,
1328
+ domain=None,
1329
+ example_data=foo),
1330
+ content: NonTensor(
1331
+ shape=torch.Size([-1]),
1332
+ space=None,
1333
+ device=None,
1334
+ dtype=None,
1335
+ domain=None,
1336
+ example_data=foo),
1337
+ device=None,
1338
+ shape=torch.Size([-1]))
1339
+ >>> print(spec.zero())
1340
+ History(
1341
+ content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
1342
+ role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
1343
+ batch_size=torch.Size([1]),
1344
+ device=None,
1345
+ is_shared=False)
1346
+
1347
+ """
1348
+ from torchrl.data import Composite, NonTensor
1349
+
1350
+ def get_default_value(field):
1351
+ if field.default is not dataclasses.MISSING:
1352
+ return field.default
1353
+ elif field.type in (str, "str"):
1354
+ return "foo"
1355
+ else:
1356
+ return None
1357
+
1358
+ defaults = {
1359
+ k: NonTensor(
1360
+ example_data=get_default_value(cls.__dataclass_fields__[k]),
1361
+ shape=shape,
1362
+ )
1363
+ for k in cls.__dataclass_fields__
1364
+ }
1365
+
1366
+ return Composite(defaults, shape=shape[:-1], data_cls=cls)
1367
+
1368
+ @classmethod
1369
+ def from_chats(cls, chats: list[list[dict]]) -> History:
1370
+ """Create a History object from a list of chats.
1371
+
1372
+ Args:
1373
+ chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries. A single chat (a plain list of dictionaries) is also accepted.
1374
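+
+ Example:
+ >>> # Illustrative sketch with made-up messages.
+ >>> chats = [
+ ...     [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
+ ...     [{"role": "user", "content": "Ping"}, {"role": "assistant", "content": "Pong"}],
+ ... ]
+ >>> history = History.from_chats(chats)  # one stacked History entry per chat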
+ """
1375
+ if isinstance(chats[0], dict):
1376
+ return lazy_stack([cls(**chat) for chat in chats])
1377
+ else:
1378
+ return lazy_stack([cls.from_chats(chat) for chat in chats])