torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/grpo/grpo_utils.py
@@ -0,0 +1,843 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+ import time
+ import warnings
+ from typing import Any, Literal
+
+ import torch
+ from omegaconf import DictConfig
+ from torch import device as torch_device, dtype as torch_dtype
+
+ from torchrl._utils import logger as torchrl_logger, timeit
+ from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
+ from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
+ from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+ from torchrl.weight_update.llm import VLLMWeightSyncScheme
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+
+ def check_grpo_dependencies() -> None:
+     """Check for required GRPO dependencies and provide helpful error messages.
+
+     This function checks for critical dependencies needed for GRPO training and
+     provides installation instructions for missing packages.
+     """
+     missing_packages = []
+     missing_optional = []
+
+     # Core required packages
+     required_packages = {
+         "datasets": "pip install datasets",
+         "peft": "pip install peft",
+         "wandb": "pip install wandb",
+         "vllm": "pip install vllm",
+         "transformers": "pip install transformers",
+         "accelerate": "pip install accelerate",
+         "ray": "pip install ray",
+         "tqdm": "pip install tqdm",
+     }
+
+     # Optional but recommended packages
+     optional_packages = {
+         "flash_attn": "pip install flash-attn",
+         "bitsandbytes": "pip install bitsandbytes",
+         "xformers": "pip install xformers",
+     }
+
+     # Check required packages
+     for package, install_cmd in required_packages.items():
+         try:
+             __import__(package)
+         except ImportError:
+             missing_packages.append((package, install_cmd))
+
+     # Check optional packages
+     for package, install_cmd in optional_packages.items():
+         try:
+             __import__(package)
+         except ImportError:
+             missing_optional.append((package, install_cmd))
+
+     # Report missing required packages
+     if missing_packages:
+         error_msg = (
+             "Missing required packages for GRPO training:\n"
+             + "\n".join(f" - {pkg}: {cmd}" for pkg, cmd in missing_packages)
+             + "\n\nYou can install all GRPO dependencies with:\n"
+             + " pip install torchrl[grpo]\n"
+             + "or install individual packages as shown above."
+         )
+         raise ImportError(error_msg)
+
+     # Report missing optional packages as warnings
+     if missing_optional:
+         warning_msg = (
+             "Missing optional packages that may improve GRPO performance:\n"
+             + "\n".join(f" - {pkg}: {cmd}" for pkg, cmd in missing_optional)
+             + "\n\nThese packages are optional but recommended for optimal performance."
+         )
+         warnings.warn(warning_msg, UserWarning, stacklevel=2)
+
+     torchrl_logger.info("✓ All required GRPO dependencies are available")
+
+
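A minimal usage sketch, assuming this file is importable as grpo_utils (the module name is an assumption):

    # Run the dependency check once at script startup, before building any training state.
    from grpo_utils import check_grpo_dependencies

    check_grpo_dependencies()  # raises ImportError with install hints if a required package is missing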
+ def get_tokenizer(cfg: DictConfig) -> PreTrainedTokenizer:
+     from transformers import AutoTokenizer
+
+     model_name = cfg.model.name
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # tokenizer.eos_token = "<|im_end|>"
+     if tokenizer.pad_token == tokenizer.eos_token:
+         tokenizer.pad_token = "PAD"
+     tokenizer.padding_side = "left"
+     return tokenizer
+
+
+ def get_train_model(
+     cfg: DictConfig,
+     devices: list[int] | None = None,
+ ) -> tuple[TransformersWrapper, PreTrainedTokenizer]:
+     """Creates and configures the training model with LoRA adapters.
+
+     This function initializes the main training model with LoRA adapters and other
+     training-specific configurations like gradient checkpointing. The model is wrapped
+     in a TransformersWrapper for policy training.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model and training settings.
+             Expected to have a train_model section with LoRA, quantization, and other
+             training-specific parameters.
+         devices (list[int], optional): The devices to use for the training model.
+             Default: `None` (falls back to `[0]`).
+
+     Returns:
+         tuple[TransformersWrapper, PreTrainedTokenizer]:
+             - policy_training: The wrapped training model
+             - train_tokenizer: The tokenizer for the model
+
+     Raises:
+         RuntimeError: If CUDA is not available or if device allocation fails
+     """
+     torchrl_logger.info("Creating train model")
+
+     # Set model dtype explicitly
+     model_dtype = getattr(torch, cfg.train_model.torch_dtype)
+
+     # Get configured devices or default to [0]
+     train_devices = devices if devices is not None else [0]
+
+     # Create max_memory dict - set 0 memory for GPUs we don't want to use
+     max_memory = {}
+     for i in range(torch.cuda.device_count()):
+         if i in train_devices:
+             max_memory[i] = "24GiB"  # Allow max memory for devices we want to use
+         else:
+             max_memory[i] = "0GiB"  # No memory for other devices
+     max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback
+
+     # Let HF handle distribution with max_memory
+     device_map = "balanced" if len(train_devices) > 1 else f"cuda:{train_devices[0]}"
+     train_model, train_tokenizer = get_hf_model(
+         cfg.model.name,
+         device_map=device_map,
+         max_memory=max_memory,
+         lora=cfg.train_model.lora.enabled,
+         lora_r=cfg.train_model.lora.r,
+         lora_alpha=cfg.train_model.lora.alpha,
+         lora_dropout=cfg.train_model.lora.dropout,
+         gradient_checkpointing=cfg.train_model.gradient_checkpointing,
+         quantize=cfg.train_model.quantization.enabled,
+         torch_dtype=model_dtype,
+         attn_implementation=cfg.train_model.attn_implementation,
+         compile=cfg.model.compile,
+     )
+
+     # Force all model parameters to the same dtype
+     for param in train_model.parameters():
+         param.data = param.data.to(model_dtype)
+
+     policy_training = TransformersWrapper(
+         train_model,
+         tokenizer=train_tokenizer,
+         input_mode="tokens" if not cfg.env.reasoning else "history",
+         generate=False,
+         return_log_probs=True,
+         pad_output=False,
+         device=torch.device("cuda:0"),
+         # Enable packing when cfg.train.packing=True by disabling padding
+         pad_model_input=not cfg.train.packing,
+     )
+     # Ensure model stays in eval mode after wrapping
+     policy_training.model.eval()
+     return policy_training, train_tokenizer
+
+
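A minimal sketch of how get_train_model might be driven, assuming a Hydra/OmegaConf config whose keys mirror the attribute accesses above (all values are placeholders):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "model": {"name": "Qwen/Qwen2.5-3B", "compile": False},
            "env": {"reasoning": False},
            "train": {"packing": False},
            "train_model": {
                "torch_dtype": "bfloat16",
                "gradient_checkpointing": True,
                "attn_implementation": "sdpa",
                "lora": {"enabled": True, "r": 8, "alpha": 16, "dropout": 0.1},
                "quantization": {"enabled": False},
            },
        }
    )
    policy_training, train_tokenizer = get_train_model(cfg, devices=[0])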
+ def get_inference_model(
+     cfg: DictConfig,
+     devices: list[int] | None = None,
+     make_ray_worker: bool = True,
+     tokenizer: PreTrainedTokenizer | None = None,
+ ) -> vLLMWrapper:
+     """Creates the vLLM-based inference model for fast generation.
+
+     This function initializes a vLLM model server for efficient inference and wraps
+     it in a vLLMWrapper for policy inference. vLLM provides optimized generation
+     with better throughput than standard HuggingFace generation.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model settings.
+             Expected to have an inference_model section with vLLM-specific parameters
+             like gpu_memory_utilization and generation settings.
+         devices (list[int], optional): The devices to use for the inference model. Default: `None`.
+         make_ray_worker (bool, optional): Whether to make a ray worker. Default: `True`.
+         tokenizer (PreTrainedTokenizer, optional): The tokenizer to use with the inference model. Default: `None`.
+
+     Returns:
+         vLLMWrapper: The wrapped vLLM model ready for inference.
+
+     Raises:
+         AssertionError: If the vLLM server or model initialization fails
+     """
+     import os
+
+     from torchrl.modules.llm.backends.vllm import AsyncVLLM
+
+     num_devices = cfg.inference_model.num_devices
+     if num_devices is None:
+         vllm_devices = devices if devices is not None else [1]
+         num_devices = len(vllm_devices)
+     else:
+         vllm_devices = None
+     torchrl_logger.info(
+         f"Creating AsyncVLLM inference model with num_devices={num_devices}, devices={vllm_devices}"
+     )
+
+     model_name = cfg.model.name
+
+     # Use AsyncVLLM for better performance and async processing
+     verbose = getattr(cfg.inference_model, "verbose", True)
+     compile_model = getattr(
+         cfg.inference_model, "compile", False
+     )  # Disabled by default for GRPO
+
+     # Build parameters dict for AsyncVLLM with all config options
+     inference_params = {
+         "model_name": model_name,
+         "num_devices": 1,
+         "num_replicas": num_devices,
+         "gpu_memory_utilization": cfg.inference_model.gpu_memory_utilization,
+         "enforce_eager": cfg.inference_model.enforce_eager,
+         "verbose": verbose,
+         "compile": compile_model,
+     }
+
+     # Configure the attention implementation to prevent Flash Attention errors.
+     # vLLM doesn't accept attn_implementation directly through AsyncEngineArgs;
+     # instead, we set the VLLM_ATTENTION_BACKEND environment variable.
+     if hasattr(cfg.inference_model, "attn_implementation"):
+         attn_impl = cfg.inference_model.attn_implementation
+
+         # Map common attention implementations to vLLM backend names
+         attn_backend_map = {
+             "flash_attention_2": "FLASH_ATTN",
+             "flash_attn": "FLASH_ATTN",
+             "sdpa": "TORCH_SDPA",
+             "torch_sdpa": "TORCH_SDPA",
+             "xformers": "XFORMERS",
+         }
+
+         vllm_backend = attn_backend_map.get(attn_impl, attn_impl.upper())
+         os.environ["VLLM_ATTENTION_BACKEND"] = vllm_backend
+
+         torchrl_logger.info(
+             f"Setting VLLM_ATTENTION_BACKEND={vllm_backend} (from config: {attn_impl})"
+         )
+
+     # Handle FP32 output configuration
+     if hasattr(cfg.inference_model, "enable_fp32_output"):
+         enable_fp32 = cfg.inference_model.enable_fp32_output
+         if enable_fp32:
+             os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
+             torchrl_logger.info(
+                 "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
+                 "This will use FP32 for the final output layer if the model supports it."
+             )
+             # Add to inference params so it gets passed to AsyncVLLM
+             inference_params["enable_fp32_output"] = enable_fp32
+
+     # Add other common vLLM parameters from config if present
+     optional_vllm_params = [
+         "max_model_len",
+         "dtype",
+         "trust_remote_code",
+         "seed",
+         "swap_space",
+         "cpu_offload_gb",
+         "enable_prefix_caching",
+         "tensor_parallel_size",
+         "pipeline_parallel_size",
+     ]
+
+     for param in optional_vllm_params:
+         if hasattr(cfg.inference_model, param):
+             value = getattr(cfg.inference_model, param)
+             if value is not None:
+                 inference_params[param] = value
+
+     # Handle torch_dtype specifically (convert string to torch dtype)
+     if hasattr(cfg.inference_model, "torch_dtype"):
+         dtype_str = cfg.inference_model.torch_dtype
+         if dtype_str is not None:
+             if isinstance(dtype_str, str):
+                 inference_params["dtype"] = getattr(torch, dtype_str)
+             else:
+                 inference_params["dtype"] = dtype_str
+
+     inference_server = AsyncVLLM.from_pretrained(**inference_params)
+     assert inference_server is not None
+     if tokenizer is None:
+         from transformers import AutoTokenizer
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         if tokenizer.pad_token == tokenizer.eos_token:
+             tokenizer.pad_token = "PAD"
+         tokenizer.padding_side = "left"
+     policy = vLLMWrapper(
+         inference_server,
+         input_mode="history",
+         chat_template_name="qwen",
+         return_log_probs=not cfg.env.reasoning,
+         tokenizer=tokenizer,
+         pad_output=False,
+         generate_kwargs={
+             "max_tokens": cfg.inference_model.max_tokens,
+             "include_stop_str_in_output": cfg.inference_model.include_stop_str_in_output,
+             "temperature": cfg.inference_model.temperature,
+             "top_p": cfg.inference_model.top_p,
+         },
+     )
+     assert policy.model is not None
+     return policy
+
+
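A sketch of the corresponding inference-side call, assuming an inference_model section whose keys mirror the accesses above (placeholder values; the training tokenizer is reused):

    from omegaconf import OmegaConf

    cfg.inference_model = OmegaConf.create(
        {
            "num_devices": None,
            "gpu_memory_utilization": 0.9,
            "enforce_eager": False,
            "max_tokens": 512,
            "include_stop_str_in_output": True,
            "temperature": 1.0,
            "top_p": 1.0,
        }
    )
    policy = get_inference_model(cfg, devices=[1], tokenizer=train_tokenizer)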
+ def get_ref_model(
+     cfg: DictConfig,
+     tokenizer: PreTrainedTokenizer,
+     devices: list[int] | None = None,
+ ) -> TransformersWrapper:
+     """Creates the reference model for KL penalty computation.
+
+     This function initializes a frozen copy of the base model to serve as the
+     reference model for KL divergence computation. The reference model is typically
+     quantized and does not require gradient computation.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model settings.
+             Expected to have a ref_model section with quantization and attention settings.
+         tokenizer (PreTrainedTokenizer): The tokenizer to use with the reference model.
+         devices (list[int], optional): The devices to use for the reference model.
+             Default: `None` (falls back to `[2]`).
+
+     Returns:
+         TransformersWrapper: The wrapped reference model in eval mode with detached weights.
+     """
+     from tensordict import TensorDict
+
+     torchrl_logger.info("Creating ref model")
+
+     # Get configured devices or default to [2]
+     if cfg.ref_model.num_devices is None:
+         ref_devices = devices if devices is not None else [2]
+     else:
+         ref_devices = list(range(cfg.ref_model.num_devices))
+
+     # Create max_memory dict - set 0 memory for GPUs we don't want to use
+     max_memory = {}
+     for i in range(torch.cuda.device_count()):
+         if i in ref_devices:
+             max_memory[i] = "24GiB"  # Allow max memory for devices we want to use
+         else:
+             max_memory[i] = "0GiB"  # No memory for other devices
+     max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback
+
+     # Let HF handle distribution with max_memory
+     device_map = "balanced" if len(ref_devices) > 1 else f"cuda:{ref_devices[0]}"
+     model_name = cfg.model.name
+
+     ref_model = get_hf_model(
+         model_name,
+         device_map=device_map,
+         max_memory=max_memory,
+         torch_dtype=getattr(torch, cfg.ref_model.torch_dtype),
+         quantize=cfg.ref_model.quantization.enabled,
+         gradient_checkpointing=cfg.ref_model.gradient_checkpointing,
+         attn_implementation=cfg.ref_model.attn_implementation,
+         lora=False,  # Reference model doesn't need LoRA
+         requires_grad=False,
+     )[0].eval()
+     # Detach weights
+     TensorDict.from_module(ref_model).data.to_module(ref_model)
+     ref_model = TransformersWrapper(
+         ref_model,
+         input_mode="tokens" if not cfg.env.reasoning else "history",
+         tokenizer=tokenizer,
+         generate=False,
+         return_log_probs=True,
+         pad_output=False,
+         device=torch.device("cuda:0"),
+     )
+     return ref_model
+
+
+ def get_hf_model(
+     model_name: str,
+     torch_dtype: torch_dtype = torch.float32,
+     lora_r: int = 8,
+     lora_alpha: int = 16,
+     lora_dropout: float = 0.1,
+     quantize: bool = False,
+     fsdp: str = "",
+     fsdp_config: Any = None,
+     gradient_checkpointing: bool = True,
+     device_map: str
+     | dict[str, int | str | torch_device]
+     | int
+     | torch_device
+     | None = None,
+     lora: bool = True,
+     attn_implementation: Literal["flash_attention_2", "flex_attention", "sdpa"]
+     | None = "flex_attention",
+     requires_grad: bool = True,
+     compile: bool = False,
+     max_memory: dict[str, str] | None = None,
+ ) -> tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+     """Creates and configures a HuggingFace model with optional optimizations.
+
+     Args:
+         model_name (str): HuggingFace model identifier (e.g., "Qwen/Qwen2.5-3B")
+         torch_dtype (torch.dtype, optional): Model precision. Default: torch.float32
+         lora_r (int, optional): LoRA rank - controls capacity of adaptations. Default: 8
+         lora_alpha (int, optional): LoRA alpha - scales the adaptations. Default: 16
+         lora_dropout (float, optional): Dropout probability for LoRA layers. Default: 0.1
+         quantize (bool, optional): Whether to enable 4-bit quantization. Default: False
+         fsdp (str, optional): Fully Sharded Data Parallel configuration. Default: ""
+         fsdp_config (Any, optional): Additional FSDP configurations. Default: None
+         gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Default: True
+         device_map (str | dict | int | torch.device | None, optional): Device placement strategy. Default: None
+         lora (bool, optional): Whether to apply LoRA adapters. Default: True
+         attn_implementation (Literal["flash_attention_2", "flex_attention", "sdpa"] | None, optional):
+             Attention implementation to use. Default: "flex_attention"
+         requires_grad (bool, optional): Whether to enable gradient computation. Default: True
+         compile (bool, optional): Whether to enable model compilation. Default: False
+         max_memory (dict[str, str], optional): Memory configuration for distributed training.
+             Default: None (treated as an empty dict)
+
+     Returns:
+         tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+             - model: The configured HuggingFace model
+             - tokenizer: The associated tokenizer
+
+     Raises:
+         ImportError: If required dependencies are not installed
+         RuntimeError: If model initialization fails
+     """
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     if max_memory is None:
+         max_memory = {}
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # tokenizer.eos_token = "<|im_end|>"
+     if tokenizer.pad_token == tokenizer.eos_token:
+         tokenizer.pad_token = "PAD"
+     tokenizer.padding_side = "left"
+
+     # Configure model settings for mixed precision
+     # Store original dtype to restore it later
+     original_dtype = torch.get_default_dtype()
+     torch.set_default_dtype(torch_dtype)
+
+     model_configs = {
+         "torch_dtype": torch_dtype,
+         "device_map": device_map if device_map is not None else "auto",
+         "max_memory": max_memory,
+     }
+     if torch.cuda.is_available() and attn_implementation:
+         torchrl_logger.info(f"{attn_implementation} init")
+         model_configs["attn_implementation"] = attn_implementation
+
+     try:
+         # Configure training settings based on FSDP usage
+         if fsdp != "" and fsdp_config is not None:
+             torchrl_logger.info("Configurations for FSDP")
+             bnb_config_params = {"bnb_4bit_quant_storage": torch_dtype}
+         else:
+             bnb_config_params = {}
+
+         # Enable Quantization
+         if quantize:
+             try:
+                 from transformers.utils.quantization_config import BitsAndBytesConfig
+             except ImportError:
+                 raise ImportError(
+                     "Please install transformers with bitsandbytes support"
+                 )
+
+             bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch_dtype,
+                 **bnb_config_params,
+             )
+             model_configs["quantization_config"] = bnb_config
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             use_cache=not gradient_checkpointing,
+             cache_dir="/tmp/.cache",
+             **model_configs,
+         )
+
+         # Configure gradient checkpointing based on FSDP usage
+         if fsdp == "" and fsdp_config is None:
+             if gradient_checkpointing:
+                 torchrl_logger.info("gradient_checkpointing enabled")
+                 model.gradient_checkpointing_enable()
+         else:
+             if gradient_checkpointing:
+                 torchrl_logger.info("gradient_checkpointing enabled")
+                 model.gradient_checkpointing_enable(
+                     gradient_checkpointing_kwargs={"use_reentrant": False}
+                 )
+
+         if lora:
+             try:
+                 from peft import get_peft_model, LoraConfig
+             except ImportError:
+                 raise ImportError("Please install peft: pip install peft")
+
+             # Create LoRA config with explicit dtype setting
+             lora_config = LoraConfig(
+                 r=lora_r,
+                 lora_alpha=lora_alpha,
+                 target_modules="all-linear",
+                 lora_dropout=0.0,  # Disable dropout for RL training
+                 bias="none",
+                 task_type="CAUSAL_LM",
+                 inference_mode=True,  # Force inference mode for consistent behavior
+                 init_lora_weights=True,  # This ensures weights are initialized
+             )
+
+             # Initialize LoRA model
+             model = get_peft_model(
+                 model,
+                 lora_config,
+                 autocast_adapter_dtype=False,  # Prevent automatic casting of adapter layers
+             )
+
+             # Force LoRA layers to correct dtype and eval mode
+             for n, p in model.named_parameters():
+                 if "lora_" in n:  # Only convert LoRA parameters
+                     p.data = p.data.to(torch_dtype)
+
+         model.eval()  # Ensure model is in eval mode
+         if requires_grad:
+             model.requires_grad_(True)
+
+         return model, tokenizer
+
+     finally:
+         # Restore original dtype
+         torch.set_default_dtype(original_dtype)
+
+
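A minimal sketch of a direct get_hf_model call, loading a 4-bit quantized, LoRA-wrapped model on one GPU (illustrative values only):

    import torch

    model, tokenizer = get_hf_model(
        "Qwen/Qwen2.5-3B",
        torch_dtype=torch.bfloat16,
        device_map=0,
        quantize=True,
        lora=True,
        lora_r=8,
        lora_alpha=16,
        gradient_checkpointing=True,
    )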
+ def make_weight_sync_scheme(
+     vllm_engine,
+ ) -> VLLMWeightSyncScheme:
+     """Creates a vLLM weight synchronization scheme using NCCL collectives.
+
+     This function creates a weight sync scheme that uses NCCL for high-performance
+     GPU-to-GPU weight transfers from the training model to vLLM inference workers.
+
+     Args:
+         vllm_engine: A vLLM engine implementing the RLvLLMEngine interface
+             (like RayLLMWorker, LocalLLMWrapper, or AsyncVLLM).
+             This is typically obtained from the inference policy's model attribute.
+
+     Returns:
+         VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
+     """
+     # Get configuration from the vLLM engine
+     tp_size = vllm_engine.get_tp_size()
+     num_replicas = getattr(vllm_engine, "num_replicas", 1)
+     master_address = vllm_engine.get_master_address()
+     master_port = vllm_engine.get_master_port()
+
+     torchrl_logger.info(
+         f"Creating VLLMWeightSyncScheme with tp_size={tp_size}, "
+         f"num_replicas={num_replicas}, master_address={master_address}, "
+         f"master_port={master_port}"
+     )
+
+     return VLLMWeightSyncScheme(
+         master_address=master_address,
+         master_port=master_port,
+         gpus_per_replica=tp_size,
+         num_replicas=num_replicas,
+         strategy="state_dict",
+     )
+
+
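A sketch of the intended wiring, following the docstring above: the engine is the inference policy's model attribute, and the resulting scheme is handed to the collector / weight updater:

    policy = get_inference_model(cfg)  # vLLMWrapper around an AsyncVLLM engine
    weight_sync_scheme = make_weight_sync_scheme(policy.model)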
+ def compute_device_allocation(cfg):
+     """Compute device allocations and Ray GPU config.
+
+     Args:
+         cfg: The configuration object
+
+     Returns:
+         dict: Updated device configuration containing:
+             - train_model_devices: list of devices for training
+             - inference_model_devices: list of devices for inference
+             - ray_num_gpus: number of GPUs to tell Ray about
+             - cuda_visible_devices: string for CUDA_VISIBLE_DEVICES
+     """
+     train_devices = cfg.train_model.num_devices
+     inf_devices = cfg.inference_model.num_devices
+
+     train_start = 0
+     train_end = train_devices
+     inference_start = 0
+     inference_end = inf_devices
+
+     ref_devices = cfg.ref_model.num_devices if cfg.train.use_kl_to_ref else 0
+     ray_num_gpus = train_devices + inf_devices + ref_devices
+
+     train_model_devices = list(range(train_start, train_end))
+     inference_model_devices = list(range(inference_start, inference_end))
+
+     all_devices = sorted(set(train_model_devices + inference_model_devices))
+     if cfg.train.use_kl_to_ref:
+         ref_device_start = max(all_devices) + 1 if all_devices else 0
+         ref_devices_list = list(
+             range(ref_device_start, ref_device_start + ref_devices)
+         )
+         all_devices.extend(ref_devices_list)
+     cuda_visible_devices = ",".join(map(str, all_devices))
+
+     return {
+         "train_model_devices": train_model_devices,
+         "inference_model_devices": inference_model_devices,
+         "ray_num_gpus": ray_num_gpus,
+         "cuda_visible_devices": cuda_visible_devices,
+     }
+
+
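A worked example; compute_device_allocation is pure arithmetic, so the result for a 2-train / 2-inference / 1-ref GPU split can be read off directly:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "train_model": {"num_devices": 2},
            "inference_model": {"num_devices": 2},
            "ref_model": {"num_devices": 1},
            "train": {"use_kl_to_ref": True},
        }
    )
    alloc = compute_device_allocation(cfg)
    # alloc == {
    #     "train_model_devices": [0, 1],
    #     "inference_model_devices": [0, 1],
    #     "ray_num_gpus": 5,
    #     "cuda_visible_devices": "0,1,2",
    # }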
+ def make_env(cfg: DictConfig, single_env: bool = False):
+     """Create the environment.
+
+     Args:
+         cfg: The configuration object
+         single_env: If True, create a single environment regardless of cfg.env.num_envs
+
+     Returns:
+         The configured environment
+     """
+     train_tokenizer = get_tokenizer(cfg)
+
+     # Setup environment
+     max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
+     if cfg.env.dataset == "gsm8k":
+         # Reward scale is 0.0 to 100
+         reward_threshold = 20
+         env = GSM8KEnv(
+             repeats=cfg.env.repeats,
+             tokenizer=train_tokenizer,
+             num_envs=cfg.env.num_envs if not single_env else 1,
+             max_steps=max_steps,
+             device=torch.device("cpu"),
+             ray_backend=True,
+         )
+     elif cfg.env.dataset == "ifeval":
+         # Reward scale is 0.0 to 2.2
+         reward_threshold = 1.0
+         env = IFEvalEnv(
+             repeats=cfg.env.repeats,
+             tokenizer=train_tokenizer,
+             num_envs=cfg.env.num_envs if not single_env else 1,
+             max_steps=max_steps,
+             device=torch.device("cpu"),
+             ray_backend=True,
+         )
+     else:
+         raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
+
+     if cfg.env.reasoning:
+         env = env.append_transform(
+             AddThinkingPrompt(
+                 cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: (
+                     td["reward"] <= reward_threshold
+                     and td["step_count"] < max_steps
+                 ),
+                 role="user",
+                 edit_last_turn=False,
+                 zero_reward=False,
+                 undo_done=True,
+                 random_prompt=True,
+             ),
+         )
+     return env
+
+
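A usage sketch, assuming cfg.env provides dataset, repeats, num_envs, reasoning and max_steps:

    env = make_env(cfg)                        # cfg.env.num_envs parallel envs for collection
    eval_env = make_env(cfg, single_env=True)  # a single env, e.g. for evaluation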
+ def make_ref_model_factory(cfg: DictConfig) -> functools.partial | None:
+     """Create a factory for the reference model if KL to ref is enabled.
+
+     Args:
+         cfg: The configuration object
+
+     Returns:
+         A partial function that creates the reference model, or None if KL to ref is disabled
+     """
+     if not cfg.train.use_kl_to_ref:
+         return None
+
+     train_tokenizer = get_tokenizer(cfg)
+     ref_cfg = DictConfig(dict(cfg))
+     ref_model_factory = functools.partial(
+         get_ref_model,
+         ref_cfg,
+         train_tokenizer,
+         devices=[0],
+     )
+     return ref_model_factory
+
+
+ def add_kl_transforms_to_replay_buffer(replay_buffer, cfg: DictConfig):
+     """Add KL transforms to the replay buffer.
+
+     Args:
+         replay_buffer: The replay buffer to add transforms to
+         cfg: The configuration object
+     """
+     if not cfg.train.use_kl_to_ref:
+         return
+
+     ref_model_factory = make_ref_model_factory(cfg)
+     if ref_model_factory is None:
+         return
+
+     if cfg.env.reasoning:
+         kl_transform = RetrieveKL(
+             ref_model_factory=ref_model_factory,
+             add_to_reward=not cfg.train.kl_coef_in_loss,
+             coeff=cfg.train.kl_to_ref_coeff,
+             use_ray_service=True,
+         )
+     else:
+         kl_transform = KLRewardTransform(
+             ref_model_factory=ref_model_factory,
+             coef=cfg.train.kl_to_ref_coeff,
+             add_to_reward=not cfg.train.kl_coef_in_loss,
+             device=torch.device("cuda:0"),
+             use_ray_service=True,
+         )
+
+     replay_buffer.append_transform(kl_transform, invert=True)
+
+
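A usage sketch; per the invert=True flag above, the KL transform runs on the buffer's write path rather than at sampling time, so reference log-probabilities are computed once per stored trajectory:

    add_kl_transforms_to_replay_buffer(replay_buffer, cfg)  # no-op when cfg.train.use_kl_to_ref is False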
+ @timeit("Logging metrics")
+ def log_training_metrics(
+     wandb_logger,
+     replay_buffer,
+     batch,
+     loss,
+     grad_norm,
+     global_step,
+     data_read_count,
+     collector,
+     start_time,
+     gradient_accumulation_steps,
+     history_str=None,
+     use_kl_to_ref=True,
+ ):
+     """Log training metrics to wandb.
+
+     Args:
+         wandb_logger: The wandb logger instance
+         replay_buffer: The replay buffer containing collected data
+         batch: The current training batch
+         loss: The computed loss object
+         grad_norm: The gradient norm value
+         global_step: Current global training step
+         data_read_count: Total data read count
+         collector: The collector instance
+         start_time: Training start time
+         gradient_accumulation_steps: Number of gradient accumulation steps
+         history_str: Optional history string for logging
+         use_kl_to_ref: Whether to log the KL-to-reference metrics. Default: True
+     """
+     with torch.no_grad():
+         rb_content = replay_buffer[:]
+         step_count = rb_content.get(("next", "step_count")).view(-1).float().mean()
+         batch_policy_version = batch["next", "policy_version"].view(-1).min()
+         batch_policy_age = collector.policy_version - batch_policy_version
+
+         metrics = {
+             "step_count from buffer": float(step_count),
+             "reward from buffer": float(
+                 torch.cat(rb_content.get(("next", "reward"), as_list=True)).mean()
+             ),
+             "seq_length from buffer": float(
+                 torch.tensor(
+                     [
+                         t.numel()
+                         for t in rb_content.get(("tokens", "response"), as_list=True)
+                     ],
+                     dtype=torch.float,
+                 ).mean()
+             ),
+             "ESS, from loss": float(loss.ESS),
+             "loss_objective, from loss": float(loss.loss_objective),
+             "clip_fraction, from loss": float(loss.clip_fraction),
+             "kl_approx (train to inference), from loss": float(loss.kl_approx),
+             "kl_to_inference (train to inference - differentiable), from loss": float(
+                 loss.kl_to_inference.mean()
+             ),
+             "loss_kl_to_inference, from loss": float(loss.loss_kl_to_inference.mean()),
+             "entropy loss, from loss": float(loss.loss_entropy.mean()),
+             "grad_norm": float(grad_norm)
+             if global_step % gradient_accumulation_steps == 0
+             else 0.0,
+             "write_count, from buffer": int(replay_buffer.write_count),
+             # how many gradient steps per write
+             "gradient_step_throughput (gradient step per write)": float(
+                 global_step / replay_buffer.write_count
+             ),
+             # how many optim steps per write
+             "optim_step_throughput (optim step per write)": float(
+                 (global_step // gradient_accumulation_steps) / replay_buffer.write_count
+             ),
+             "data_read_count (total)": data_read_count,
+             "current_policy_version (collector)": collector.policy_version,
+             # FIXME: Assume batch is a single trajectory
+             # FIXME: The addition of the transform after the env instantiation + _shuttle creation
+             # is messed up - we need the next data
+             "batch_policy_version (sampled batch)": batch_policy_version,
+             "batch_policy_age (sampled batch)": batch_policy_age,
+             "throughput (steps per second)": float(
+                 global_step / (time.time() - start_time)
+             ),
+         }
+         if use_kl_to_ref:
+             metrics["kl_penalty (inference to ref) from buffer"] = float(
+                 torch.cat(rb_content.get(("next", "kl_penalty"), as_list=True)).mean()
+             )
+             metrics["kl_to_ref, from loss"] = float(loss.kl_to_ref.mean())
+             metrics["loss_kl_to_ref, from loss"] = float(loss.loss_kl_to_ref.mean())
+
+         for name, value in metrics.items():
+             wandb_logger.log_scalar(name, value, step=global_step)
+
+         if history_str is not None:
+             wandb_logger.log_str("history", history_str, step=global_step)