torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/utils.py
@@ -0,0 +1,543 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import abc
+ import collections
+ import importlib
+
+ import numpy as np
+ import torch
+ from tensordict import TensorDict
+ from torch import nn, Tensor
+ from torch.nn import functional as F
+
+ from torchrl.data.llm.prompt import PromptData
+
+ _has_transformers = importlib.util.find_spec("transformers") is not None
+
+
+ class KLControllerBase(abc.ABC):
+     """Base class for KL controllers.
+
+     Each controller must implement an update method that takes the current KL value and
+     the number of steps and updates the kl_coef attribute of the wrapped model,
+     which will multiply the KL during calculation of the reward.
+     """
+
+     @abc.abstractmethod
+     def update(self, kl_values: list[float]) -> float:
+         ...
+
+
+ class ConstantKLController(KLControllerBase):
+     """Constant KL Controller.
+
+     This controller maintains a fixed coefficient no matter what values it is updated
+     with.
+
+     Keyword Arguments:
+         kl_coef (:obj:`float`): The coefficient to multiply KL with when calculating the
+             reward.
+         model (nn.Module, optional): wrapped model that needs to be controlled.
+             Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
+             be updated in-place.
+     """
+
+     def __init__(
+         self,
+         *,
+         kl_coef: float | None = None,
+         model: nn.Module | None = None,
+     ):
+         self.model = model
+         if model is not None and not hasattr(model, "kl_coef"):
+             raise AttributeError(
+                 "Model input to ConstantKLController doesn't have attribute 'kl_coef'"
+             )
+         self.coef = kl_coef
+         if model is not None:
+             self.model.kl_coef = self.coef
+
+     def update(self, kl_values: list[float] = None) -> float:
+         if self.model is not None:
+             self.model.kl_coef = self.coef
+         return self.coef
+
+
+ class AdaptiveKLController(KLControllerBase):
+     """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".
+
+     Keyword Arguments:
+         init_kl_coef (:obj:`float`): The starting value of the coefficient.
+         target (:obj:`float`): The target KL value. When the observed KL is smaller, the
+             coefficient is decreased, thereby relaxing the KL penalty in the training
+             objective and allowing the model to stray further from the reference model.
+             When the observed KL is greater than the target, the KL coefficient is
+             increased, thereby pulling the model back towards the reference model.
+         horizon (int): Scaling factor to control how aggressively we update the
+             coefficient.
+         model (nn.Module, optional): wrapped model that needs to be controlled.
+             Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
+             be updated in-place.
+
+     Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
+     Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
+     """
+
+     def __init__(
+         self,
+         *,
+         init_kl_coef: float,
+         target: float,
+         horizon: int,
+         model: nn.Module | None = None,
+     ):
+         self.model = model
+         self.coef = init_kl_coef
+         self.target = target
+         self.horizon = horizon
+         if model is not None:
+             self.model.kl_coef = self.coef
+
+     def update(self, kl_values: list[float]):
+         """Update ``self.coef`` adaptively.
+
+         Arguments:
+             kl_values (sequence of float): The current KL value between the newest policy and the initial
+                 policy.
+
+         """
+         if kl_values is None:
+             raise ValueError(
+                 f"The kl_values were not provided to {type(self)}. "
+                 f"Make sure these values are provided for the scheduler to be updated "
+                 f"accordingly. "
+             )
+         n_steps = len(kl_values)
+         # renormalize kls
+         kl_value = -torch.as_tensor(kl_values).mean() / self.coef
+         proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2)  # ϵₜ
+         mult = 1 + proportional_error * n_steps / self.horizon
+         self.coef *= mult  # βₜ₊₁
+         if self.model is not None:
+             self.model.kl_coef = self.coef
+         return self.coef
+
+
+ class RolloutFromModel:
+     """A class for performing rollouts with causal language models.
+
+     It is assumed that the model this class wraps takes tokenized text as input and
+     that its task is to predict the next word in a sentence, having read the n previous
+     words.
+
+     Args:
+         model (transformers.Transformer): the model to be used. Should have a
+             :meth:`generate` method.
+         ref_model (transformers.Transformer): a frozen version of ``model``
+             where params are in their initial configuration. This is used to compute a
+             KL penalty for the reward, to stop the model from straying too far from the
+             reference model during training.
+         reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
+             ``input_ids`` and ``attention_mask``, calculates rewards for each token and
+             end_scores (the reward for the final token in each sequence).
+         kl_coef: (:obj:`float`, optional): initial kl coefficient.
+         max_new_tokens (int, optional): the maximum length of the sequence.
+             Defaults to 50.
+         score_clip (:obj:`float`, optional): Scores from the reward model are clipped to the
+             range ``(-score_clip, score_clip)``. Defaults to 10.
+         kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler.
+         num_steps (int, optional): number of steps between two optimizations.
+
+     Examples:
+         >>> from tensordict.nn import TensorDictModule
+         >>> from torchrl.modules.models.llm import GPT2RewardModel
+         >>> from torchrl.data.llm.utils import RolloutFromModel
+         >>> from torchrl.data.llm.dataset import get_dataloader
+         >>> from torchrl.data.llm.prompt import PromptData
+         >>> from transformers import GPT2LMHeadModel
+         >>>
+         >>> dl = get_dataloader(
+         ...     batch_size=4,
+         ...     block_size=550,
+         ...     tensorclass_type=PromptData,
+         ...     device="cpu",
+         ...     dataset_name="CarperAI/openai_summarize_tldr",
+         ... )
+         >>> model = GPT2LMHeadModel.from_pretrained("gpt2")
+         >>> # we load ref_model with random weights so it differs from model
+         >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
+         >>> reward_model = GPT2RewardModel(model_path="gpt2")
+         >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
+         >>>
+         >>> batch = next(dl)
+         >>> rollout = rollout_from_model.rollout_from_data(batch)
+         >>> rollout
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
+                 attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
+                 input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
+                         done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
+                         reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                     batch_size=torch.Size([4, 50]),
+                     device=cpu,
+                     is_shared=False),
+                 sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+             batch_size=torch.Size([4, 50]),
+             device=cpu,
+             is_shared=False)
+
+     """
+
+     EOS_TOKEN_ID = 50256
+
+     def __init__(
+         self,
+         model,
+         ref_model,
+         reward_model,
+         kl_coef=0.1,
+         max_new_tokens=50,
+         score_clip=10.0,
+         kl_scheduler: KLControllerBase | None = None,
+         num_steps: int | None = None,
+     ):
+         if not _has_transformers:
+             raise ImportError(
+                 "transformers module couldn't be found. Make sure it is installed in your "
+                 "environment."
+             )
+         self.model = model
+         self.ref_model = ref_model
+         self.reward_model = reward_model
+         self.max_new_tokens = max_new_tokens
+         self.score_clip = score_clip
+         self.kl_coef = kl_coef
+         self.kl_scheduler = kl_scheduler
+         if num_steps is not None:
+             self._kl_queue = collections.deque(maxlen=num_steps)
+         else:
+             # we create a list. Values appended to it will be detached scalars, so they are very cheap to store,
+             # even if the update is not called.
+             # The scheduler update will take care of erasing these values.
+             self._kl_queue = []
+
+     @torch.no_grad()
+     def rollout_from_data(self, batch):
+         generated, log_probs, log_ratio = self.generate(batch)
+         return self.create_rollout_td(batch, generated, log_probs, log_ratio)
+
+     @torch.no_grad()
+     def create_rollout_td(self, batch, generated, log_probs, log_ratio):
+         """A TensorDict wrapper for generated data.
+
+         This function takes a batch plus the generated tokens and replicates the
+         tensordict structure that would have been obtained from a rollout with a TorchRL
+         env that sampled one token each timestep.
+
+         Args:
+             batch (TensorDict): A batch of data containing the original prompt together with a field
+                 "rindex" indicating the right index of the prompt.
+             generated (torch.Tensor): Tokenized prompt followed by generated tokens. This can be obtained
+                 by calling the ``generate`` method.
+             log_probs (torch.Tensor): The log probabilities of the generated tokens. Can be obtained by
+                 calling the ``generate`` method.
+             log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens
+                 according to the generative model and the reference model. Can be
+                 obtained by calling the ``generate`` method.
+
+         Returns:
+             A :class:`~tensordict.TensorDict` with the following keys:
+
+             - ``"action"``: the sequence of actions (generated tokens)
+             - ``"input_ids"``: the input_ids passed to the generative model at each time
+               step.
+             - ``"attention_mask"``: the attention_masks passed to the generative model at
+               each time step
+             - ``"sample_log_prob"``: the log probability of each token during generation
+             - ``("next", "input_ids")``: the sequence of tokens after generation. Makes up
+               part of the inputs that will be used for generating the next token.
+             - ``("next", "attention_mask")``: updated attention_mask after token has been
+               generated. Passed to the generative model on the next time step
+             - ``("next", "terminated")``: Boolean array indicating whether we've reached a
+               terminal state (either because we generated EOS token or because we
+               reached the token limit)
+             - ``("next", "done")``: Boolean array indicating whether we've reached a
+               final state. Currently a copy of ``"terminated"``.
+             - ``("next", "reward")``: The reward received at each time step
+             - ``("next", "reward_raw")``: The raw reward from the reward model, without the
+               KL term. This is mainly for debugging and logging, it is not used in
+               training
+             - ``("next", "reward_kl")``: The KL term from the reward. This is mainly for
+               debugging and logging, it is not used in training.
+
+         """
+         rollout_generated = self._get_rollout_generated(generated, batch)
+         rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool()
+
+         done, terminated = self._get_done_status(generated, batch)
+         action = self._get_action(generated, batch)
+         end_scores, end_scores_labels = self._get_end_scores(
+             rollout_generated, rollout_attention_mask, batch
+         )
+
+         # the reward is zero except for the timestep where we reached a stopping condition
+         clipped_scores = torch.clip(
+             end_scores - end_scores_labels, -self.score_clip, self.score_clip
+         )
+         reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
+         reward_raw = reward_raw * done
+         reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1)
+         reward = reward_raw + reward_kl
+         td = {
+             "action": action,
+             "input_ids": rollout_generated[:, :-1].clone(),
+             "attention_mask": rollout_attention_mask[:, :-1].clone(),
+             "sample_log_prob": log_probs,
+             "next": {
+                 "input_ids": rollout_generated[:, 1:].clone(),
+                 "attention_mask": rollout_attention_mask[:, 1:].clone(),
+                 "done": done,
+                 "terminated": terminated,
+                 "reward": reward,
+                 "reward_raw": reward_raw,
+                 "reward_kl": reward_kl,
+             },
+         }
+         self._kl_queue.append(reward_kl.detach().mean())
+         return TensorDict(
+             td, batch_size=done.shape[:2], device=generated.device
+         ).refine_names(..., "time")
+
+     def _get_rollout_generated(self, generated, batch):
+         # stack the individual timesteps during generation into a single tensor
+         rollout_generated = []
+         arange = torch.arange(generated.shape[1], device=generated.device)
+         for rindex, row in zip(batch.prompt_rindex, generated):
+             tokens = []
+             for i in range(self.max_new_tokens + 1):
+                 tokens.append(torch.where(arange < rindex + i, row, self.EOS_TOKEN_ID))
+             rollout_generated.append(torch.stack(tokens))
+         rollout_generated = torch.stack(rollout_generated)
+         return rollout_generated
+
+     def _get_done_status(self, generated, batch):
+         # done is True when we either first sample an EOS token or reach the maximum number
+         # of generated tokens
+         done_idx = torch.minimum(
+             (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
+             torch.as_tensor(self.max_new_tokens) - 1,
+         )
+         truncated_idx = (
+             torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as(
+                 done_idx
+             )
+             - 1
+         )
+         zeros = torch.zeros(
+             done_idx.numel(),
+             self.max_new_tokens,
+             dtype=torch.bool,
+             device=generated.device,
+         )
+         truncated = zeros.scatter(-1, truncated_idx.unsqueeze(-1), 1).unsqueeze(-1)
+         done = zeros.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1)
+         terminated = (
+             done & ~truncated
+         )  # we assume that if it's not truncated, it was terminated
+         return truncated | terminated, terminated
+
+     def _get_action(self, generated, batch):
+         # the sequence of actions for each trajectory is just the generated token ids
+         action_idx = torch.arange(self.max_new_tokens, device=generated.device)
+         action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1)
+         return generated.gather(-1, action_idx)
+
+     def _get_end_scores(self, rollout_generated, rollout_attention_mask, batch):
+         # calculate the reward for the finished sequence
+         _, end_scores = self.reward_model(
+             input_ids=rollout_generated[:, -1],
+             attention_mask=rollout_attention_mask[:, -1],
+         )
+         _, end_scores_labels = self.reward_model(
+             input_ids=batch.input_ids,
+             attention_mask=batch.attention_mask,
+         )
+         return end_scores, end_scores_labels
+
+     @classmethod
+     def _padded_right_to_left(cls, tensor, *, eos_token_id=None, dim=1):
+         if eos_token_id is None:
+             eos_token_id = cls.EOS_TOKEN_ID
+         mask = tensor != eos_token_id
+         out = torch.full_like(tensor, eos_token_id)
+         out[mask.flip(dim)] = tensor[mask]
+         return out
+
+     @classmethod
+     def _padded_left_to_right(
+         cls, tensor, *, sequence_length=None, eos_token_id=None, dim=1
+     ):
+         # some care must be taken here, because generated sequences may have both left
+         # and right padding, and may also have terminated early if all sequences in the
+         # batch reached EOS before reaching the token limit
+         if sequence_length is None:
+             sequence_length = tensor.size(dim)
+         if dim < 0:
+             dim = tensor.ndim + dim
+         if eos_token_id is None:
+             eos_token_id = cls.EOS_TOKEN_ID
+         mask = tensor != eos_token_id
+         # convert [0, 0, 1, 1, 0] to [0, 0, 1, 1, 1] to avoid right eos
+         mask = ~((~mask).to(torch.uint8).cumprod(dim).bool())
+         shape = list(mask.shape)
+         shape[dim] = sequence_length
+         out = torch.full(torch.Size(shape), eos_token_id, device=tensor.device)
+         index = (slice(None),) * dim + (slice(tensor.size(dim)),)
+         out[index][mask.flip(dim)] = tensor[mask]
+         return out
+
+     @property
+     def _default_conf(self):
+         from transformers import GenerationConfig
+
+         return GenerationConfig(
+             pad_token_id=self.EOS_TOKEN_ID,
+             max_new_tokens=self.max_new_tokens,
+             return_dict_in_generate=True,
+             output_scores=True,
+             do_sample=True,
+         )
+
+     def _get_scores(
+         self, scores: tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None
+     ):
+         scores = torch.stack(scores, 1)
+         if scores.shape[1] != self.max_new_tokens:
+             scores = F.pad(
+                 scores,
+                 (0, 0, 0, self.max_new_tokens - scores.shape[1]),
+                 value=float("-inf"),
+             )
+         scores = F.log_softmax(scores, dim=-1)
+         if use_max:
+             scores = scores.max(dim=-1).values
+         else:
+             index = generated_tokens.unsqueeze(-1)
+             scores = torch.gather(scores, dim=-1, index=index)
+         if pad_to is not None:
+             pad = pad_to - scores.shape[1]
+             return F.pad(scores, (0, pad), value=-float("inf"))
+         return scores
+
+     @staticmethod
+     def logprobs_of_labels(logits, labels):
+         """Log probabilities of the labels.
+
+         These are calculated from the logits. The labels (token ids) are used to index
+         the logits along the relevant dimension.
+         """
+         logprobs = F.log_softmax(logits, dim=-1)
+         logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))
+         return logprobs_labels.squeeze(-1)
+
+     @torch.no_grad()
+     def _log_ratio(self, generated, prompt_rindex):
+         # get the scores and normalise for log probabilities
+         attention_mask = (generated != self.EOS_TOKEN_ID).bool()
+         logits = self.model(
+             input_ids=generated, attention_mask=attention_mask, return_dict=True
+         ).logits
+         logprobs = self.logprobs_of_labels(logits[:, :-1], generated[:, 1:])
+         ref_logits = self.ref_model(
+             input_ids=generated.to(self.ref_model.device),
+             attention_mask=attention_mask.to(self.ref_model.device),
+             return_dict=True,
+         ).logits.to(logits.device)
+         ref_logprobs = self.logprobs_of_labels(ref_logits[:, :-1], generated[:, 1:])
+         log_ratio = logprobs - ref_logprobs
+         log_ratio = log_ratio.masked_fill(~attention_mask[:, :-1], 0)
+         log_ratio = torch.stack(
+             [
+                 row[rindex - 1 : rindex + self.max_new_tokens - 1]
+                 for row, rindex in zip(log_ratio, prompt_rindex)
+             ],
+             dim=0,
+         )
+         return log_ratio
+
+     def _get_generated_tokens(self, generated, rindex):
+         # extracts the generated tokens from the full sequence of prompt + generated
+         idx = torch.arange(generated.shape[1], device=generated.device)
+         rindex = rindex.unsqueeze(-1)
+         mask = (idx >= rindex) & (idx < rindex + self.max_new_tokens)
+         return generated[mask].reshape(-1, self.max_new_tokens)
+
+     @torch.no_grad()
+     def generate(self, batch: PromptData, generation_config=None):
+         """Generates a sequence of tokens from a batch of data sampled from the data collector.
+
+         Args:
+             batch (PromptData): the data to be used. Must have ``input_ids``
+                 and ``prompt_rindex`` fields.
+             generation_config (GenerationConfig, optional): the configuration for the
+                 call to generate.
+
+         Returns:
+             generated (torch.Tensor): a [B x (Ti + To)] sequence of integers (tokens),
+                 where Ti is the length of the input sequence and To is the length
+                 of the generated sequence.
+             log_probs_gen: the log-probabilities of the tokens generated.
+             log_ratio: the log ratio between probabilities under the generative
+                 model and the frozen version.
+
+         """
+         input_ids = batch.mask_label().input_ids
+
+         # move padding tokens to left pad
+         # huggingface models expect left padding for generation
+         input_ids = self._padded_right_to_left(input_ids)
+
+         # generate and capture scores
+         if generation_config is None:
+             generation_config = self._default_conf
+
+         attention_mask = (input_ids != self.EOS_TOKEN_ID).bool()
+         outputs = self.model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+         )
+         samples = outputs.sequences
+
+         # we'll insert generated tokens into a tensor prepopulated with padding tokens,
+         # thereby moving back to right padding for reward model
+         generated = self._padded_left_to_right(
+             samples,
+             sequence_length=input_ids.shape[1] + self.max_new_tokens,
+             eos_token_id=self.EOS_TOKEN_ID,
+         )
+         generated_tokens = self._get_generated_tokens(generated, batch.prompt_rindex)
+         # get the scores and normalise for log probabilities
+         log_probs_gen = self._get_scores(outputs.scores, generated_tokens)
+
+         log_ratio = self._log_ratio(generated, batch.prompt_rindex)
+         return generated, log_probs_gen, log_ratio
+
+     def step_scheduler(self):
+         # recover true kl
+         self.kl_coef = self.kl_scheduler.update(self._kl_queue)
+         if isinstance(self._kl_queue, (list, collections.deque)):
+             # remove all values
+             while len(self._kl_queue):
+                 self._kl_queue.remove(self._kl_queue[0])
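
For reference, here is a minimal usage sketch (not part of the package diff) showing how the adaptive KL scheduler defined above can be wired into RolloutFromModel. The model and reward-model setup mirrors the class docstring's example; the target, horizon and num_steps values are illustrative assumptions rather than recommended settings.

from transformers import GPT2LMHeadModel

from torchrl.data.llm.dataset import get_dataloader
from torchrl.data.llm.prompt import PromptData
from torchrl.data.llm.utils import AdaptiveKLController, RolloutFromModel
from torchrl.modules.models.llm import GPT2RewardModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
ref_model = GPT2LMHeadModel.from_pretrained("gpt2")  # frozen reference copy of the policy
reward_model = GPT2RewardModel(model_path="gpt2")

# Adaptive schedule from Ziegler et al.: the coefficient is nudged so the observed KL
# tracks `target`; `horizon` controls how aggressively it moves (values are examples).
kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6.0, horizon=10000)

rollout_from_model = RolloutFromModel(
    model,
    ref_model,
    reward_model,
    kl_coef=0.1,
    kl_scheduler=kl_scheduler,
    num_steps=64,  # bounds the queue of KL values consumed by step_scheduler()
)

dl = get_dataloader(
    batch_size=4,
    block_size=550,
    tensorclass_type=PromptData,
    device="cpu",
    dataset_name="CarperAI/openai_summarize_tldr",
)
batch = next(dl)
rollout = rollout_from_model.rollout_from_data(batch)  # TensorDict with reward, reward_raw, reward_kl
rollout_from_model.step_scheduler()  # refresh kl_coef from the KL values queued during rollouts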
torchrl/data/map/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
+ from .query import HashToInt, QueryModule
+ from .tdstorage import TensorDictMap, TensorMap
+ from .tree import MCTSForest, Tree
+
+ __all__ = [
+     "BinaryToDecimal",
+     "RandomProjectionHash",
+     "SipHash",
+     "HashToInt",
+     "QueryModule",
+     "TensorDictMap",
+     "TensorMap",
+     "MCTSForest",
+     "Tree",
+ ]
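
The second hunk above defines the public surface of torchrl.data.map. As a quick illustration (not part of the diff), the re-exported names can be imported directly from the subpackage; only the names listed in __all__ above are assumed to exist:

from torchrl.data.map import (
    BinaryToDecimal,
    HashToInt,
    MCTSForest,
    QueryModule,
    RandomProjectionHash,
    SipHash,
    TensorDictMap,
    TensorMap,
    Tree,
)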