torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,352 @@
1
+ # Expert Iteration: Learning from Top-K Responses
2
+
3
+ This is an implementation of Expert Iteration for language models, built on top of TorchRL.
4
+ Expert Iteration is a reinforcement learning-like method that learns from the best performing responses in a batch, rather than using all responses equally.
5
+
6
+ The idea of these scripts is extremely simple:
7
+ - Collect some trajectories with a pre-trained version of the model;
8
+ - Select the top-K best trajectories of the batch (based on their reward);
9
+ - Train the model using SFT on these trajectories;
10
+ - Update the inference model.
11
+
12
+ ## Overview
13
+
14
+ The version of Expert Iteration presented here has the following features:
15
+
16
+ - **Top-K Selection**: Only the best performing responses are used for training, improving sample efficiency
17
+ - **KL Regularization**: Maintains model quality by penalizing divergence from a reference model
18
+ - **Multi-GPU support** with efficient device management
19
+ - **Mixed precision training** for memory efficiency
20
+ - **Gradient accumulation** for larger effective batch sizes
21
+ - **Automatic checkpointing** and comprehensive logging with Weights & Biases
22
+ - **Hydra configuration system** for easy experimentation
23
+ - **Asynchronous training support** with Ray for improved throughput
24
+ - **Prioritized sampling** such that samples with higher rewards have a higher chance of being sampled (illustrated in the sketch below)
25
+
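+ The prioritized-sampling feature above boils down to reward-weighted sampling from the replay buffer. The snippet below is only a minimal sketch of that principle with plain `torch`; the actual scripts configure this through the buffer's sampler, whose API may differ.
+
+ ```python
+ import torch
+
+ # Hypothetical rewards attached to 8 samples stored in the buffer.
+ rewards = torch.tensor([0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.5, 0.6])
+
+ # Turn rewards into sampling probabilities: higher reward -> higher chance.
+ probs = rewards / rewards.sum()
+
+ # Draw a mini-batch of 4 indices, weighted by reward.
+ batch_indices = torch.multinomial(probs, num_samples=4, replacement=True)
+ print(batch_indices)
+ ```
+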
26
+ ## Key Differences from GRPO and other RL algorithms
27
+
28
+ ### 1. Top-K Reward Selection
29
+
30
+ Unlike other RL post-training recipes (e.g., GRPO), which use all responses,
31
+ Expert Iteration employs a `TopKRewardSelector` transform that:
32
+
33
+ - Collects multiple responses for each prompt (controlled by `env.repeats`)
34
+ - Selects only the top-k responses based on reward (controlled by `train.topk_size`)
35
+ - Writes only the best responses to the replay buffer, improving training efficiency
36
+
37
+ ```python
38
+ # Example: For each prompt, generate 32 responses but only keep the best 4
39
+ env.repeats = 32 # Generate 32 responses per prompt
40
+ train.topk_size = 4 # Keep only the top 4 responses
41
+ ```
42
+
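+ The selection step itself can be pictured as a per-prompt `torch.topk` over the rewards of the generated responses. The following is a standalone sketch of that idea, not the actual `TopKRewardSelector` implementation (whose signature may differ):
+
+ ```python
+ import torch
+
+ repeats, topk_size = 32, 4
+
+ # Rewards for one prompt: `repeats` candidate responses, one scalar reward each.
+ rewards = torch.randn(repeats)
+
+ # Indices of the `topk_size` best responses; only these would be written
+ # to the replay buffer.
+ top_rewards, top_indices = torch.topk(rewards, k=topk_size)
+ print(top_indices.tolist(), top_rewards.tolist())
+ ```
+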
43
+ ### 2. KL Divergence Handling
44
+
45
+ Expert Iteration uses a different approach to KL regularization:
46
+ - **No KL in reward**: Unlike GRPO's `KLRewardTransform`, Expert Iteration doesn't add KL penalties to the reward signal
47
+ - **KL in loss function**: KL divergence is computed directly in the loss function using `SFTLoss` with `kl_to_ref_coeff`
48
+ - **Reference log probabilities**: The `RetrieveLogProb` transform extracts reference model log probabilities for KL computation
49
+
50
+ ```python
51
+ # KL is handled in the loss function, not in the reward
52
+ loss_fn = SFTLoss(
53
+     actor_network=policy_training,
54
+     kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,  # KL penalty coefficient
55
+     tokenizer=train_tokenizer,
56
+     tokenizer_kwargs={"chat_template_name": "qwen"},
57
+     device=train_device,
58
+ )
59
+ ```
60
+
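+ In loss terms, this amounts to adding a weighted KL penalty to the SFT (negative log-likelihood) objective. The snippet below is a toy illustration of that combination from per-token log-probabilities, using a simple sample-based KL estimate; it is a sketch of the idea, not the internals of `SFTLoss`:
+
+ ```python
+ import torch
+
+ kl_to_ref_coeff = 1.0
+
+ # Stand-in per-token log-probs of the selected (top-k) tokens under the
+ # trained policy and under the frozen reference model (shape: batch x tokens).
+ logp_policy = -torch.rand(4, 128)
+ logp_ref = -torch.rand(4, 128)
+
+ # SFT term: negative log-likelihood of the top-k responses under the policy.
+ sft_loss = -logp_policy.mean()
+
+ # Sample-based estimate of KL(policy || reference) at the same token positions.
+ kl_penalty = (logp_policy - logp_ref).mean()
+
+ loss = sft_loss + kl_to_ref_coeff * kl_penalty
+ print(loss)
+ ```
+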
61
+ ### 3. Reduced Weight Updates
62
+
63
+ Expert Iteration can afford fewer policy weight updates thanks to its selective training approach: one can freely choose a longer `update_weight_frequency` interval (e.g., every 100 or more optimization steps).
64
+
65
+ ## Installation
66
+
67
+ Install dependencies:
68
+ ```bash
69
+ # GSM8K deps
70
+ pip install -r sota-implementations/expert-iteration/requirements_gsm8k.txt
71
+ # IFEval deps
72
+ pip install -r sota-implementations/expert-iteration/requirements_ifeval.txt
73
+ ```
74
+
75
+ ## Hardware Requirements
76
+
77
+ - At least 3 CUDA-capable GPUs:
78
+ - Training device(s)
79
+ - vLLM inference device
80
+ - Reference model device
81
+
82
+ ### Device Management
83
+
84
+ The number of devices for each model component is specified using `num_devices`:
85
+
86
+ ```bash
87
+ train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
88
+ ```
89
+
90
+ This approach:
91
+
92
+ - Automatically handles device allocation;
93
+ - Works correctly in both sync and async modes;
94
+ - Prevents device conflicts between model components;
95
+ - Is more portable across different machine configurations.
96
+
97
+ ## Configuration
98
+
99
+ The training configuration is managed through Hydra. There are two main configuration files:
100
+ - `config/ei_gsm8k.yaml`: Configuration for GSM8K tasks (default)
101
+ - `config/ei_ifeval.yaml`: Configuration optimized for IFEval tasks
102
+
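+ To sanity-check what a given run will use, the composed configuration can be inspected with Hydra's compose API. This is a minimal sketch (config names taken from the list above, overrides in the same syntax as the CLI examples below); it assumes it is executed from the `expert-iteration` directory:
+
+ ```python
+ from hydra import compose, initialize
+ from omegaconf import OmegaConf
+
+ # Compose the GSM8K config and apply two CLI-style overrides, just to
+ # inspect the resulting configuration without launching training.
+ with initialize(config_path="config", version_base=None):
+     cfg = compose(
+         config_name="ei_gsm8k",
+         overrides=["train.topk_size=8", "env.repeats=64"],
+     )
+ print(OmegaConf.to_yaml(cfg))
+ ```
+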
103
+ ## Usage
104
+
105
+ ### Basic Training
106
+
107
+ There are two training modes available:
108
+
109
+ #### Synchronous Mode (Default)
110
+ ```bash
111
+ python sota-implementations/expert-iteration/expert-iteration-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
112
+ ```
113
+
114
+ #### Asynchronous Mode (Recommended)
115
+ ```bash
116
+ python sota-implementations/expert-iteration/expert-iteration-async.py mode=async train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
117
+ ```
118
+
119
+ The key difference between sync and async modes is how data collection and optimization are handled:
120
+
121
+ **Synchronous Mode (expert-iteration-sync.py)**:
122
+ ```python
123
+ # Three nested loops:
124
+ for data in collector: # Data collection loop
125
+     for epoch in range(epochs):  # Epoch loop
126
+         for batch in replay_buffer:  # Buffer consumption loop
127
+             # Optimize on batch (only top-k responses)
128
+             loss = loss_fn(batch)
129
+             loss.backward()
130
+             optimizer.step()
131
+     # Weight update
132
+     weight_updater.push_weights(policy_training)
133
+ ```
134
+
135
+ **Asynchronous Mode (expert-iteration-async.py)**:
136
+ ```python
137
+ # Start data collection in background
138
+ collector.start()
139
+
140
+ # Single optimization loop
141
+ for step in range(total_steps):
142
+     # Sample and optimize (only top-k responses)
143
+     batch = replay_buffer.sample()
144
+     loss = loss_fn(batch)
145
+     loss.backward()
146
+     optimizer.step()
147
+     # Update weights once in a while
148
+     if step % weight_update_frequency == 0:
149
+         weight_updater.push_weights(policy_training)
150
+ ```
151
+
152
+ Key differences:
153
+ 1. **Data Collection**:
154
+ - Sync: Data collection and optimization happen sequentially (unless `train.sync_iter=false`)
155
+ - Async: Data collection runs in background while optimization happens
156
+
157
+ 2. **Buffer Size**:
158
+ - Sync: Buffer size must equal the batch size returned by the collector
159
+ - Async: Buffer can be larger than the batch size, allowing for more diverse sampling
160
+
161
+ 3. **Data Processing**:
162
+ - Sync: Processes the same data multiple times (epochs)
163
+ - Async: Each piece of data is processed a non-deterministic number of times
164
+
165
+ 4. **Weight updates**:
166
+ - Sync: Weights are updated before every collection of data
167
+ - Async: Weights are updated at a given interval (in gradient steps)
168
+
169
+ The async mode offers better performance by:
170
+
171
+ - Running data collection and optimization concurrently
171
+ - Using the GPU more efficiently
172
+ - Reducing memory overhead
173
+ - Improving throughput
174
+ - Allowing more flexible buffer management
176
+
177
+ ### Top-K Configuration
178
+
179
+ The key parameters for top-k selection are:
180
+
181
+ ```yaml
182
+ env:
183
+   repeats: 32        # Number of responses to generate per prompt
184
+ train:
185
+   topk_size: 4       # Number of best responses to keep for training
186
+ ```
187
+
188
+ **Recommendations**:
189
+
190
+ - Higher `repeats` values provide more diversity but increase computation
191
+ - `topk_size` should be 10-20% of `repeats` for good selection pressure
192
+ - Typical values: `repeats=32, topk_size=4` or `repeats=64, topk_size=8`
193
+
194
+ It is critical to have a reward function that is granular enough for `top-k` to be of any use: a binary reward will have a median value that
195
+ does not provide much insight into which outputs outrank others.
196
+
197
+ ### KL Regularization
198
+
199
+ KL divergence is controlled via the `kl_to_ref_coeff` parameter:
200
+
201
+ ```yaml
202
+ train:
203
+   kl_to_ref_coeff: 1.0  # KL penalty coefficient
204
+ ```
205
+
206
+ **Recommendations**:
207
+
208
+ - Start with `kl_to_ref_coeff=1.0` and adjust based on model quality.
209
+ - Higher values keep the model closer to the reference.
210
+ - Lower values allow more exploration but risk quality degradation.
211
+
+ **Note**: Expert Iteration is a rather simple algorithm with few convergence guarantees. Starting with a high KL regularization coefficient and lowering it progressively is advisable.
212
+
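+ One simple way to follow that advice is to anneal the coefficient over training. The configs above set a fixed value; the helper below is only a hypothetical sketch of a linear decay one could wire into the training loop:
+
+ ```python
+ def kl_coeff_schedule(step: int, total_steps: int, start: float = 1.0, end: float = 0.1) -> float:
+     """Linearly anneal the KL-to-reference coefficient from `start` down to `end`."""
+     frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
+     return start + (end - start) * frac
+
+
+ # Coefficient at the beginning, middle and end of a 1000-step run.
+ print([round(kl_coeff_schedule(s, 1000), 3) for s in (0, 500, 1000)])
+ ```
+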
213
+ ### Run with IFEval Config
214
+
215
+ ```bash
216
+ python expert-iteration-sync.py mode=sync --config-name ei_ifeval
217
+ ```
218
+
219
+ ### Override Config Values
220
+
221
+ ```bash
222
+ # Change dataset
223
+ python expert-iteration-sync.py mode=sync env.dataset=ifeval
224
+
225
+ # Modify top-k parameters
226
+ python expert-iteration-sync.py mode=sync env.repeats=64 train.topk_size=8
227
+
228
+ # Adjust KL regularization
229
+ python expert-iteration-sync.py mode=sync train.kl_to_ref_coeff=0.5
230
+
231
+ # Change model
232
+ python expert-iteration-sync.py mode=sync model.name=meta-llama/Llama-2-7b-hf
233
+ ```
234
+
235
+ ### Hyperparameter Sweeps
236
+
237
+ ```bash
238
+ # Top-k size sweep
239
+ python expert-iteration-sync.py mode=sync --multirun train.topk_size=2,4,8
240
+
241
+ # KL coefficient sweep
242
+ python expert-iteration-sync.py mode=sync --multirun train.kl_to_ref_coeff=0.5,1.0,2.0
243
+
244
+ # Multiple parameters
245
+ python expert-iteration-sync.py mode=sync --multirun \
246
+     train.topk_size=4,8 \
247
+     train.kl_to_ref_coeff=0.5,1.0
248
+ ```
249
+
250
+ Don't forget to set `train.total_dialog_turns` to a reasonable value!
251
+
252
+ ## Monitoring
253
+
254
+ Training progress is logged to Weights & Biases with the following metrics:
255
+
256
+ - **Reward**: Average reward of responses in the buffer
257
+ - **Sequence length**: Average length of generated responses
258
+ - **KL divergence**: KL divergence from reference model
259
+ - **Loss metrics**: SFT loss, KL loss, and total loss
260
+ - **Gradient norm**: Gradient clipping statistics
261
+ - **Throughput metrics**: Steps per second, gradient steps per write
262
+ - **Buffer statistics**: Write count, policy version tracking
263
+
264
+ ### Collector Logging
265
+
266
+ The collector is given a `RemoteDataLogger` postproc hook that passes the data to a Ray queue, consumed by the training node for logging.
267
+
268
+ This approach ensures:
269
+ - Single wandb run with all metrics (training + collector)
270
+ - No conflicts between multiple wandb loggers
271
+ - Centralized logging through the main process
272
+
273
+ The collector logs the following metrics:
274
+ - **Collector rewards**: Mean, std, min, max of rewards from collected data
275
+ - **Response lengths**: Mean, std, min, max of response lengths
276
+ - **Policy versions**: Mean, min, max of policy versions (for async mode)
277
+ - **Time elapsed**: Time between collection batches
278
+
279
+ To add new collector metrics, modify the `log_data` method in `RemoteDataLogger` in `ei_utils.py`.
280
+
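+ The pattern is a simple producer/consumer handoff. The sketch below reproduces it with a standard-library queue in a single process (the real scripts use a Ray queue shared across processes, and `RemoteDataLogger.log_data` operates on the collected batch, so details differ):
+
+ ```python
+ import queue
+ import threading
+
+ log_queue: "queue.Queue[dict]" = queue.Queue()
+
+ def collector_side() -> None:
+     # The collector computes batch statistics and pushes them to the queue.
+     for version in range(3):
+         log_queue.put({"collector/reward_mean": 0.5, "policy_version": version})
+
+ def trainer_side() -> None:
+     # The training node drains the queue and logs everything to the single
+     # wandb run, alongside the training metrics.
+     for _ in range(3):
+         print("log to wandb:", log_queue.get())
+
+ producer = threading.Thread(target=collector_side)
+ producer.start()
+ trainer_side()
+ producer.join()
+ ```
+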
281
+ ## Checkpointing
282
+
283
+ Checkpoints are saved every `train.checkpoint_frequency` steps and contain:
284
+ - Model state
285
+ - Optimizer state
286
+ - Gradient scaler state (for mixed precision)
287
+ - Full configuration
288
+
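+ Checkpoints are regular PyTorch files, so they can be inspected outside the training script. A minimal sketch (the path is hypothetical and the key names are whatever the training scripts store):
+
+ ```python
+ import torch
+
+ # Hypothetical path produced by a previous run.
+ ckpt_path = "outputs/2024-01-01/12-00-00/checkpoints/checkpoint_1000.pt"
+
+ # weights_only=False because the checkpoint holds more than raw tensors
+ # (optimizer/scaler state and the full configuration).
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ print(list(ckpt.keys()))
+ ```
+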
289
+ ## Debugging Out-of-memory issues
290
+
291
+ - **vLLM**: Reduce `inference_model.gpu_memory_utilization=FRACTION` or the number of environments run in parallel (`env.num_envs=N`)
292
+ - **Reference model**: If the reference model computation is memory-intensive, reduce the number of environments (`env.num_envs=N`) run in parallel
293
+ - **Training**: Reduce the batch size (`train.optim_batch_size`), or lean on gradient accumulation (see the sketch below)
294
+ - **Top-k**: Reduce `env.repeats` to generate fewer responses per prompt
295
+
296
+ ## Directory Structure
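+ For training-side OOMs, gradient accumulation (listed in the Overview) keeps the effective batch size while shrinking what each backward pass holds in memory. A generic, self-contained sketch of the pattern with a toy model (the real loop uses the SFT loss and replay buffer shown earlier):
+
+ ```python
+ import torch
+
+ model = torch.nn.Linear(8, 1)
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+ accumulate_steps = 4  # effective batch = accumulate_steps * micro-batch size
+
+ optimizer.zero_grad()
+ for i in range(16):
+     micro_batch = torch.randn(2, 8)  # stand-in for a sampled micro-batch
+     loss = model(micro_batch).pow(2).mean() / accumulate_steps  # scale the loss
+     loss.backward()  # gradients accumulate across micro-batches
+     if (i + 1) % accumulate_steps == 0:
+         optimizer.step()  # one optimizer step per effective batch
+         optimizer.zero_grad()
+ ```
+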
297
+
298
+ ```
299
+ sota-implementations/expert-iteration/
300
+ ├── config/
301
+ │   ├── ei_gsm8k.yaml          # Main configuration file
301
+ │   ├── ei_ifeval.yaml         # Configuration for IFEval task
302
+ │   └── mode/
303
+ │       ├── async.yaml         # Async mode settings
304
+ │       └── sync.yaml          # Sync mode settings
306
+ ├── expert-iteration-sync.py # Synchronous training script
307
+ ├── expert-iteration-async.py # Asynchronous training script
308
+ ├── ei_utils.py # Utility functions
309
+ └── README.md # This file
310
+ ```
311
+
312
+ ## Output Structure
313
+
314
+ Each run creates a timestamped directory under `outputs/`:
315
+ ```
316
+ outputs/
317
+ └── YYYY-MM-DD/
317
+     └── HH-MM-SS/
318
+         ├── checkpoints/
319
+         │   └── checkpoint_*.pt
320
+         └── .hydra/
321
+             └── config.yaml
323
+ ```
324
+
325
+ For hyperparameter sweeps, outputs are stored under `multirun/`.
326
+
327
+ ## Theoretical Background
328
+
329
+ Expert Iteration is based on the principle of learning from the best examples rather than all examples. The key insights are:
330
+
331
+ 1. **Selective Learning**: By only training on high-quality responses, the model learns more efficiently
332
+ 2. **Quality over Quantity**: A smaller dataset of high-quality examples can be more effective than a larger dataset of mixed quality
333
+ 3. **Iterative Improvement**: Each iteration produces better responses, which become the training data for the next iteration
334
+
335
+ This approach is particularly effective for language model training where:
336
+
337
+ - Response quality varies significantly
338
+ - High-quality responses are rare but valuable
339
+ - The model can learn to imitate good responses more effectively than avoid bad ones
340
+
341
+ In theory, one could run Expert Iteration on samples gathered from other LLMs or from expert datasets, although convergence will be harder to control due to
342
+ the inability to use the KL regularization factor.
343
+
344
+ ## Comparison with Other Methods
345
+
346
+ | Method | Training Data | KL Handling | Update Frequency |
347
+ |--------|---------------|-------------|------------------|
348
+ | **Expert Iteration** | Top-k responses | In loss function | Reduced (can be less frequent) |
349
+ | **GRPO** | All responses | In reward / loss | Standard |
350
+ | **DPO** | Preference pairs | Implicit in loss | Standard |
351
+
352
+ Expert Iteration's key advantage is its sample efficiency: by focusing on the best responses, it can achieve better performance with fewer training examples and less frequent policy updates.