torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,482 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import contextlib
8
+ import warnings
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Literal, TYPE_CHECKING
12
+
13
+ import torch
14
+ from tensordict import NestedKey, TensorClass, TensorDictBase
15
+ from tensordict.nn import TensorDictModule
16
+ from tensordict.utils import _zip_strict
17
+ from torchrl.data import History
18
+ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
19
+ from torchrl.objectives.common import LossModule
20
+
21
+ if TYPE_CHECKING:
22
+ import transformers
23
+
24
+
25
def sft_loss(summed_log_probs: torch.Tensor, reduction: str) -> torch.Tensor:
    """Compute the SFT loss.

    Args:
        summed_log_probs (torch.Tensor): per-sample summed log probabilities
            of the assistant tokens.
        reduction (str): one of ``"mean"``, ``"sum"`` or ``"none"``.

    Returns:
        The negative log-likelihood, reduced as requested.

    Raises:
        ValueError: if ``reduction`` is not a recognized mode.
    """
    # The objective is the negative log-likelihood; negate once up front.
    nll = -summed_log_probs
    if reduction == "mean":
        return nll.mean()
    if reduction == "sum":
        return nll.sum()
    if reduction == "none":
        return nll
    raise ValueError(f"Invalid reduction: {reduction}.")
36
+
37
+
38
def minor_sft_loss(
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    beta: float,
    reduction: str,
) -> torch.Tensor:
    """Compute the MinorSFT loss.

    A DPO-inspired, gentler alternative to plain SFT: the per-element loss is
    ``-log_sigmoid(beta * (log_probs - ref_log_probs))``, so training pressure
    vanishes as the model's likelihood overtakes the reference model's.

    Args:
        log_probs (torch.Tensor): The log probabilities from the model being trained.
        ref_log_probs (torch.Tensor): The log probabilities from the reference model.
        beta (float): The beta parameter from DPO.
        reduction (str): The reduction to apply to the loss (``"mean"``, ``"sum"`` or ``"none"``).

    Returns:
        The MinorSFT loss.

    Raises:
        ValueError: if the two log-probability tensors differ in shape, or if
            ``reduction`` is not a recognized mode.

    References:
        - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
          `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
    """
    # Element-wise comparison requires identical shapes; fail loudly otherwise.
    if log_probs.shape != ref_log_probs.shape:
        raise ValueError(
            f"Current log probabilities and reference log probabilities have different shapes: {log_probs.shape=} vs {ref_log_probs.shape=}."
        )
    margin = log_probs - ref_log_probs
    elementwise = -torch.nn.functional.logsigmoid(beta * margin)
    if reduction == "mean":
        return elementwise.mean()
    if reduction == "sum":
        return elementwise.sum()
    if reduction == "none":
        return elementwise
    raise ValueError(f"Invalid reduction: {reduction}")
74
+
75
+
76
class SFTLossOutput(TensorClass["nocast"]):
    """SFT Loss Output.

    Attributes:
        loss_sft (torch.Tensor): The loss for the SFT objective.
        loss_kl_to_ref (torch.Tensor | None): The loss for the KL divergence to the reference model.
        kl_to_ref (torch.Tensor | None): The KL divergence to the reference model.

    .. note::
        The loss components are kept separate to allow for logging and visualization.
        Before backpropagation, the loss components are to be summed together. Since non-loss components are not differentiable
        when the loss is constructed via :class:`~torchrl.objectives.llm.sft.SFTLoss`, summing
        the :class:`~torchrl.objectives.llm.sft.SFTLossOutput` directly is a proper way of obtaining the total loss.

        >>> loss_fn = SFTLoss(...)
        >>> loss_output = loss_fn(td)
        >>> loss = loss_output.loss_sft + loss_output.loss_kl_to_ref
        >>> loss.backward()
        >>> # or equivalently
        >>> loss = loss_fn(td)
        >>> loss.sum(reduce=True).backward()
    """

    # Differentiable SFT objective term; always populated.
    loss_sft: torch.Tensor
    # KL-to-reference loss term; None when no KL regularization is configured.
    loss_kl_to_ref: torch.Tensor | None = None
    # Raw KL divergence kept for logging/visualization only (not a loss term).
    kl_to_ref: torch.Tensor | None = None
102
+
103
+
104
+ class SFTLoss(LossModule):
105
+ r"""Supervised fine-tuning loss.
106
+
107
+ Args:
108
+ actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance,
109
+ with `return_log_prob=True` and `from_text=True`.
110
+ tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`.
111
+ tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`.
112
+ This can be used to override arguments such as the `chat_template` or `chat_template_name`.
113
+ reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`.
114
+ normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`.
115
+ kl_to_ref_coeff (float | None, optional): coefficient for KL divergence to reference model. Defaults to `None`.
116
+ loss_function (Literal["sft", "minor_sft"], optional): The loss function to use. Defaults to `"sft"`.
117
+ beta (float, optional): The beta parameter for MinorSFT loss. This is only used when `loss_function` is `"minor_sft"`.
118
+ Higher values of beta make the loss more aggressive (pushes the model to generate responses further from the reference model):
119
+
120
+ .. math::
121
+ \text{loss} = -\log\sigma(\beta \cdot (\text{log_probs} - \text{ref_log_probs}))
122
+
123
+ Defaults to `0.1`.
124
+ device (torch.device | None, optional): the device to use for the loss, when tokenizing the input. Defaults to `None`.
125
+
126
+ .. note::
127
+ The input tensordict is expected to contain the following keys by default:
128
+ - ``("next", "history")``: The chat history
129
+ - ``("next", "ref_log_probs")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set
130
+
131
+ These keys can be customized using the ``set_keys()`` method.
132
+
133
+ .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation.
134
+
135
+ References:
136
+ - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
137
+ `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
138
+
139
+ Examples:
140
+ >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
141
+ >>> from torchrl.modules.llm import TransformersWrapper
142
+ >>> from torchrl.objectives.llm.sft import SFTLoss
143
+ >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
144
+ >>> from tensordict import TensorDict, lazy_stack
145
+ >>> import torch
146
+ >>>
147
+ >>> # Create chat data
148
+ >>> chats = [
149
+ ... [
150
+ ... {"role": "system", "content": "You are a helpful assistant."},
151
+ ... {"role": "user", "content": "Hello, how are you?"},
152
+ ... {"role": "assistant", "content": "I'm doing well, thank you!"},
153
+ ... ],
154
+ ... [
155
+ ... {"role": "system", "content": "You are a helpful assistant."},
156
+ ... {"role": "user", "content": "What's the weather like?"},
157
+ ... {"role": "assistant", "content": "I can't check the weather for you."},
158
+ ... ],
159
+ ... ]
160
+ >>> history = History.from_chats(chats)
161
+ >>>
162
+ >>> # Setup tokenizer and model
163
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
164
+ >>> tokenizer.pad_token = tokenizer.eos_token
165
+ >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
166
+ >>> model = OPTForCausalLM(OPTConfig()).eval()
167
+ >>>
168
+ >>> # Create training and reference policies
169
+ >>> policy_train = TransformersWrapper(
170
+ ... model,
171
+ ... tokenizer=tokenizer,
172
+ ... generate=False,
173
+ ... from_text=True,
174
+ ... chat_template_name="qwen",
175
+ ... )
176
+ >>> policy_ref = TransformersWrapper(
177
+ ... model,
178
+ ... tokenizer=tokenizer,
179
+ ... generate=False,
180
+ ... from_text=True,
181
+ ... return_log_probs=True,
182
+ ... chat_template_name="qwen",
183
+ ... )
184
+ >>>
185
+ >>> # Create the RetrieveLogProb transform
186
+ >>> transform = RetrieveLogProb(
187
+ ... policy_ref,
188
+ ... assistant_only=True,
189
+ ... tokenizer_kwargs={"chat_template_name": "qwen"},
190
+ ... tokenizer=tokenizer,
191
+ ... )
192
+ >>>
193
+ >>> # Prepare data
194
+ >>> text = history[:, :-1].apply_chat_template(
195
+ ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
196
+ ... )
197
+ >>> text_response = history.apply_chat_template(
198
+ ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
199
+ ... )
200
+ >>> text_response = [
201
+ ... txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
202
+ ... ]
203
+ >>> td = TensorDict(
204
+ ... text=text,
205
+ ... text_response=text_response,
206
+ ... history=history,
207
+ ... next=TensorDict(
208
+ ... reward=torch.randn(2, 1),
209
+ ... done=torch.zeros(2, dtype=torch.bool),
210
+ ... history=history,
211
+ ... ),
212
+ ... batch_size=(2,),
213
+ ... )
214
+ >>> data = lazy_stack(list(td.unbind(0)))
215
+ >>>
216
+ >>> # Apply the transform to get reference log probabilities
217
+ >>> data = transform(data)
218
+ >>> assert "ref_log_probs" in data["next"].keys()
219
+ >>>
220
+ >>> # Use with SFTLoss for KL regularization
221
+ >>> loss = SFTLoss(
222
+ ... actor_network=policy_train,
223
+ ... tokenizer=tokenizer,
224
+ ... reduction="mean",
225
+ ... normalize_by_seq_length=True,
226
+ ... kl_to_ref_coeff=0.1,
227
+ ... tokenizer_kwargs={"chat_template_name": "qwen"},
228
+ ... loss_function="sft",
229
+ ... )
230
+ >>> loss_vals = loss(data)
231
+ >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
232
+ >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
233
+
234
+ """
235
+
236
+ @dataclass
237
+ class _AcceptedKeys:
238
+ """Maintains default values for all configurable tensordict keys.
239
+
240
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
241
+ default values.
242
+
243
+ Attributes:
244
+ history (NestedKey): The input tensordict key where the chat history is expected.
245
+ Defaults to ``("next", "history")``.
246
+ ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
247
+ Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_probs")``.
248
+ log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
249
+ Defaults to ``"log_probs"``.
250
+ """
251
+
252
+ history: NestedKey = ("history", "full")
253
+ ref_log_prob: NestedKey = ("next", "ref_log_probs", "full")
254
+ log_probs: NestedKey = ("log_probs", "full")
255
+
256
+ default_keys = _AcceptedKeys
257
+ tensor_keys: _AcceptedKeys
258
+
259
+ def __init__(
260
+ self,
261
+ actor_network: TensorDictModule | TransformersWrapper,
262
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
263
+ tokenizer_kwargs: dict | None = None,
264
+ reduction: Literal["mean", "sum", "none"] = "mean",
265
+ normalize_by_seq_length: bool = True,
266
+ kl_to_ref_coeff: float | None = None,
267
+ loss_function: Literal["sft", "minor_sft"] = "sft",
268
+ beta: float = 0.1,
269
+ device: torch.device | None = None,
270
+ ):
271
+ super().__init__()
272
+ self.in_keys = []
273
+ self.actor_network = actor_network
274
+ if tokenizer is None:
275
+ tokenizer = actor_network.tokenizer
276
+ self.tokenizer = tokenizer
277
+ if tokenizer_kwargs is None:
278
+ tokenizer_kwargs = {}
279
+ if tokenizer is None:
280
+ raise ValueError("Tokenizer must be provided.")
281
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
282
+ tokenizer_kwargs.setdefault("tokenize", True)
283
+ tokenizer_kwargs.setdefault("return_tensors", "pt")
284
+ tokenizer_kwargs.setdefault("padding", False)
285
+ tokenizer_kwargs.setdefault("add_generation_prompt", False)
286
+ self.tokenizer_kwargs = tokenizer_kwargs
287
+ self.reduction = reduction
288
+ self.normalize_by_seq_length = normalize_by_seq_length
289
+ self.kl_to_ref_coeff = kl_to_ref_coeff
290
+ self.loss_function = loss_function
291
+ if self.loss_function == "minor_sft" and kl_to_ref_coeff:
292
+ warnings.warn(
293
+ "kl_to_ref_coeff should not be set when using minor_sft loss, as KL regularization is implicit. Setting kl_to_ref_coeff to 0.0."
294
+ )
295
+ self.kl_to_ref_coeff = 0.0
296
+ self.beta = beta
297
+ self._set_in_keys()
298
+ self.device = device
299
+
300
+ def _set_in_keys(self) -> None:
301
+ """Sets the input keys for the loss module."""
302
+ in_keys = [self.tensor_keys.history]
303
+ if self.kl_to_ref_coeff is not None or self.loss_function == "minor_sft":
304
+ in_keys.append(self.tensor_keys.ref_log_prob)
305
+ self.in_keys = in_keys
306
+ self.out_keys = [] # Loss modules typically don't have out_keys
307
+
308
+ def _kl_to_ref(
309
+ self,
310
+ cur_log_prob: list[torch.Tensor],
311
+ ref_log_prob: list[torch.Tensor],
312
+ ) -> tuple[torch.Tensor, torch.Tensor]:
313
+ """Compute KL divergence to reference model.
314
+
315
+ Args:
316
+ cur_log_prob (List[torch.Tensor]): Log probabilities from current model. Must have shape [T] where T is the number of tokens in the assistant response.
317
+ ref_log_prob (List[torch.Tensor]): Log probabilities from reference model. Must have shape [T] where T is the number of tokens in the assistant response.
318
+
319
+ Returns:
320
+ tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging)
321
+ """
322
+ # Apply mask
323
+ ref_log_prob = torch.cat(ref_log_prob)
324
+ cur_log_prob = torch.cat(cur_log_prob)
325
+ # ref_log_prob = ref_log_prob[mask]
326
+ # cur_log_prob = cur_log_prob[mask].squeeze()
327
+ if cur_log_prob.shape != ref_log_prob.shape:
328
+ raise ValueError(
329
+ f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}."
330
+ )
331
+ # Compute KL using same approximation as GRPO
332
+ diff = ref_log_prob - cur_log_prob
333
+
334
+ kl_penalty = (diff.expm1() - diff).mean()
335
+ return self.kl_to_ref_coeff * kl_penalty, kl_penalty
336
+
337
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
338
+ # Gather history
339
+ history: History = tensordict[self.tensor_keys.history]
340
+
341
+ # Try to get mask from td
342
+ token_struct = None
343
+ assistant_masks = tensordict.get(("masks", "all_assistant_mask"), as_list=True)
344
+ attention_mask = tensordict.get(("masks", "all_attention_mask"), as_list=True)
345
+ if assistant_masks is None:
346
+ # Apply tokenizer to history and gather mask
347
+ with torch.device(
348
+ self.device
349
+ ) if self.device is not None else contextlib.nullcontext():
350
+ token_struct = history.apply_chat_template(
351
+ tokenizer=self.tokenizer, **self.tokenizer_kwargs
352
+ )
353
+ if "assistant_masks" not in token_struct:
354
+ raise ValueError(
355
+ f"Assistant masks are not present in the token structure: {token_struct=}."
356
+ )
357
+ assistant_masks = token_struct.get(
358
+ "assistant_masks",
359
+ as_list=True,
360
+ )
361
+ attention_mask = token_struct.get("attention_mask", as_list=True)
362
+ assistant_masks = [mask.bool() for mask in assistant_masks]
363
+ attention_mask = [mask.bool() for mask in attention_mask]
364
+ assistant_masks = [
365
+ mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask)
366
+ ]
367
+
368
+ if not any(mask.any(-1).all() for mask in assistant_masks):
369
+ raise ValueError("Some inputs have no valid assistant masks.")
370
+
371
+ input_loss = tensordict.select(self.tensor_keys.history)
372
+
373
+ with torch.device(
374
+ self.device
375
+ ) if self.device is not None else contextlib.nullcontext():
376
+ output_loss = self.actor_network(input_loss)
377
+
378
+ # get log-probs
379
+ log_probs = output_loss.get(
380
+ self.tensor_keys.log_probs,
381
+ as_list=True,
382
+ )
383
+
384
+ # apply mask
385
+ if not all(
386
+ mask.shape == lp.shape
387
+ for mask, lp in _zip_strict(assistant_masks, log_probs)
388
+ ):
389
+ if token_struct is not None:
390
+ suffix = f"Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
391
+ else:
392
+ suffix = ""
393
+ raise ValueError(
394
+ f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs "
395
+ f"{[lp.shape for lp in log_probs]}. {suffix}"
396
+ )
397
+
398
+ log_probs_masked = [
399
+ lp.masked_fill(~mask, 0.0)
400
+ for lp, mask in _zip_strict(log_probs, assistant_masks)
401
+ ]
402
+
403
+ # Sum log probs, optionally normalize by sequence length
404
+ summed_log_probs = torch.stack(
405
+ [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked]
406
+ )
407
+ seq_lengths = torch.stack(
408
+ [mask.sum(tensordict.ndim - 1) for mask in assistant_masks]
409
+ )
410
+ if self.normalize_by_seq_length:
411
+ # Compute sequence lengths for normalization (number of assistant tokens)
412
+ summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1)
413
+
414
+ # Compute main loss
415
+ if self.loss_function == "sft":
416
+ loss = sft_loss(summed_log_probs, self.reduction)
417
+ # Add KL divergence loss if reference model is provided
418
+ if self.kl_to_ref_coeff is not None:
419
+ ref_log_probs = tensordict.get(
420
+ self.tensor_keys.ref_log_prob,
421
+ default=None,
422
+ as_list=True,
423
+ )
424
+ if ref_log_probs is None:
425
+ raise ValueError(
426
+ f"Reference log probs not found in tensordict at key {self.tensor_keys.ref_log_prob} but kl_to_ref_coeff was set. "
427
+ f"Existing keys in tensordict: {set(tensordict.keys(include_nested=True, leaves_only=True))}"
428
+ )
429
+
430
+ log_probs_masked = [
431
+ lp.masked_fill(~mask, 0.0)
432
+ for lp, mask in _zip_strict(log_probs, assistant_masks)
433
+ ]
434
+
435
+ loss_kl, kl_penalty = self._kl_to_ref(
436
+ log_probs_masked,
437
+ ref_log_probs,
438
+ )
439
+ output = SFTLossOutput(
440
+ loss_sft=loss,
441
+ loss_kl_to_ref=loss_kl,
442
+ kl_to_ref=kl_penalty.detach(),
443
+ )
444
+ else:
445
+ output = SFTLossOutput(loss_sft=loss)
446
+ elif self.loss_function == "minor_sft":
447
+ ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
448
+ if ref_log_probs is None:
449
+ raise ValueError(
450
+ f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict with keys {tensordict.keys(True, True)} but loss_function is 'minor_sft'"
451
+ )
452
+
453
+ # we need to re-sum ref_log_probs as they are not summed per-sequence
454
+ summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to(
455
+ summed_log_probs.device
456
+ )
457
+ if self.normalize_by_seq_length:
458
+ summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1)
459
+ loss = minor_sft_loss(
460
+ summed_log_probs, summed_ref_log_probs, self.beta, self.reduction
461
+ )
462
+ if self.kl_to_ref_coeff is not None:
463
+ with torch.no_grad():
464
+ log_probs_masked = [
465
+ lp.masked_fill(~mask, 0.0)
466
+ for lp, mask in _zip_strict(log_probs, assistant_masks)
467
+ ]
468
+ loss_kl, kl_penalty = self._kl_to_ref(
469
+ log_probs_masked,
470
+ ref_log_probs,
471
+ )
472
+ output = SFTLossOutput(
473
+ loss_sft=loss,
474
+ loss_kl_to_ref=loss_kl,
475
+ kl_to_ref=kl_penalty.detach(),
476
+ )
477
+ else:
478
+ output = SFTLossOutput(loss_sft=loss)
479
+ else:
480
+ raise ValueError(f"Invalid loss function: {self.loss_function}")
481
+
482
+ return output
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .qmixer import QMixerLoss
7
+
8
+ __all__ = ["QMixerLoss"]