torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/expert-iteration/ei_utils.py
@@ -0,0 +1,770 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import time
+ from typing import Any, Literal
+
+ import torch
+ from omegaconf import DictConfig
+
+ from tensordict import TensorDict
+ from torch import device as torch_device, dtype as torch_dtype
+
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.envs.llm import RetrieveLogProb
+ from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
+ from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+ from torchrl.weight_update.llm import VLLMWeightSyncScheme
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+ try:
+     import ray
+ except ImportError:
+     ray = None
+
+
+ def get_tokenizer(cfg: DictConfig) -> PreTrainedTokenizer:
+     """Return the tokenizer for ``cfg.model.name``, left-padded, with a pad token distinct from EOS."""
+     from transformers import AutoTokenizer
+
+     model_name = cfg.model.name
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # tokenizer.eos_token = "<|im_end|>"
+     if tokenizer.pad_token == tokenizer.eos_token:
+         tokenizer.pad_token = "PAD"
+     tokenizer.padding_side = "left"
+     return tokenizer
+
+
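+ # Example (illustrative sketch, not part of the module; assumes a hydra-style
+ # config whose ``model.name`` points at a checkpoint on the Hub):
+ #     >>> cfg = DictConfig({"model": {"name": "Qwen/Qwen2.5-3B"}})
+ #     >>> tok = get_tokenizer(cfg)
+ #     >>> tok.padding_side
+ #     'left'
+
+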
+ def make_env(cfg: DictConfig, devices: list[int] | None = None):
+     """Create the environment with proper device allocation.
+
+     Args:
+         cfg: The configuration object.
+         devices: The devices to use for the reference model.
+
+     Returns:
+         The configured environment.
+     """
+     # Create the reference model with proper device allocation.
+     # For the collector actor, we want inference_model devices first, then ref_model devices.
+     train_tokenizer = get_tokenizer(cfg)
+
+     # Create a new config with adjusted device assignments
+     ref_cfg = DictConfig(dict(cfg))
+     ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)
+
+     # Setup environment
+     if cfg.env.dataset == "gsm8k":
+         from torchrl.envs.llm import GSM8KEnv
+
+         env = GSM8KEnv(
+             repeats=cfg.env.repeats,
+             tokenizer=train_tokenizer,
+             num_envs=cfg.env.num_envs,
+             device=torch.device("cpu"),
+         )
+     else:  # ifeval
+         env = IFEvalEnv(
+             repeats=cfg.env.repeats,
+             tokenizer=train_tokenizer,
+             num_envs=cfg.env.num_envs,
+             device=torch.device("cpu"),
+         )
+
+     # Pass the device directly to RetrieveLogProb. Under Ray, the local device
+     # index is always 0, so "cuda:0" is correct here.
+     device = torch.device("cuda:0")
+     env = env.append_transform(
+         RetrieveLogProb(
+             model=ref_model,
+             assistant_only=True,
+             tokenizer_kwargs={"chat_template_name": "qwen"},
+             device=device,
+             log_probs_full_key=("ref_log_probs", "full"),
+         )
+     )
+     return env
+
+
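+ # Example (illustrative sketch; assumes ``cfg.env`` carries ``dataset``,
+ # ``repeats`` and ``num_envs``, and that the ``ref_model`` section consumed by
+ # get_ref_model is populated):
+ #     >>> env = make_env(cfg, devices=[1])  # reference model on cuda:1
+ #     >>> td = env.reset()                  # batched reset over cfg.env.num_envs
+
+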
+ def get_train_model(
+     cfg: DictConfig,
+     devices: list[int] | None = None,
+     chat_template_name: str | None = None,
+ ) -> tuple[TransformersWrapper, PreTrainedTokenizer]:
+     """Creates and configures the training model with LoRA adapters.
+
+     This function initializes the main training model with LoRA adapters and other
+     training-specific configurations such as gradient checkpointing. The model is
+     wrapped in a TransformersWrapper for policy training.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model and training
+             settings. Expected to have a train_model section with LoRA, quantization,
+             and other training-specific parameters.
+         devices (list[int] | None, optional): The devices to use for the training model.
+             Defaults to ``None``.
+         chat_template_name (str | None, optional): The name of the chat template to use.
+             Defaults to ``None``.
+
+     Returns:
+         tuple[TransformersWrapper, PreTrainedTokenizer]:
+             - policy_training: The wrapped training model
+             - train_tokenizer: The tokenizer for the model
+
+     Raises:
+         RuntimeError: If CUDA is not available or if device allocation fails.
+     """
+     torchrl_logger.info("Creating train model")
+
+     # Set model dtype explicitly
+     model_dtype = getattr(torch, cfg.train_model.torch_dtype)
+
+     # Get configured devices or default to [0]
+     train_devices = devices if devices is not None else [0]
+
+     # Create max_memory dict - set 0 memory for GPUs we don't want to use
+     max_memory = {}
+     for i in range(torch.cuda.device_count()):
+         if i in train_devices:
+             max_memory[i] = "24GiB"  # Allow max memory for devices we want to use
+         else:
+             max_memory[i] = "0GiB"  # No memory for other devices
+     max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback
+
+     # Let HF handle distribution with max_memory
+     device_map = "balanced" if len(train_devices) > 1 else f"cuda:{train_devices[0]}"
+     train_model, train_tokenizer = get_hf_model(
+         cfg.model.name,
+         device_map=device_map,
+         max_memory=max_memory,
+         lora=cfg.train_model.lora.enabled,
+         lora_r=cfg.train_model.lora.r,
+         lora_alpha=cfg.train_model.lora.alpha,
+         lora_dropout=cfg.train_model.lora.dropout,
+         gradient_checkpointing=cfg.train_model.gradient_checkpointing,
+         quantize=cfg.train_model.quantization.enabled,
+         torch_dtype=model_dtype,
+         attn_implementation=cfg.train_model.attn_implementation,
+         compile=cfg.model.compile,
+         eval_mode=cfg.train_model.eval,
+     )
+
+     # Force all model parameters to the same dtype
+     for param in train_model.parameters():
+         param.data = param.data.to(model_dtype)
+
+     policy_training = TransformersWrapper(
+         train_model,
+         tokenizer=train_tokenizer,
+         input_mode="history",
+         generate=False,
+         return_log_probs=True,
+         pad_output=False,
+         device=torch.device("cuda:0"),
+     )
+     # Ensure the model stays in eval mode after wrapping (eval() is train(False))
+     policy_training.model.eval()
+     return policy_training, train_tokenizer
+
+
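+ # Example (illustrative sketch; ``cfg`` is assumed to carry the ``train_model``
+ # section described above, and ``data`` a history-formatted TensorDict batch):
+ #     >>> policy_training, tok = get_train_model(cfg, devices=[0])
+ #     >>> out = policy_training(data)  # log-probs only, since generate=False
+
+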
+ def get_inference_model(
+     cfg: DictConfig,
+     devices: list[int] | None = None,
+     make_ray_worker: bool = True,
+     tokenizer: PreTrainedTokenizer | None = None,
+ ) -> vLLMWrapper:
+     """Creates the vLLM-based inference model for fast generation.
+
+     This function initializes a vLLM model server for efficient inference and wraps
+     it in a vLLMWrapper for policy inference. vLLM provides optimized generation
+     with better throughput than standard HuggingFace generation.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model settings.
+             Expected to have an inference_model section with vLLM-specific parameters
+             such as gpu_memory_utilization and generation settings.
+         devices (list[int] | None, optional): The devices to use for the inference
+             model. Default: ``None``.
+         make_ray_worker (bool, optional): Whether to make a ray worker. Default: ``True``.
+         tokenizer (PreTrainedTokenizer | None, optional): The tokenizer to use.
+             Default: ``None``.
+
+     Returns:
+         vLLMWrapper: The wrapped vLLM model ready for inference.
+
+     Raises:
+         AssertionError: If the vLLM server or model initialization fails.
+     """
+     from torchrl.modules.llm.backends import make_vllm_worker
+
+     num_devices = cfg.inference_model.num_devices
+     if num_devices is None:
+         vllm_devices = devices if devices is not None else [1]
+     else:
+         vllm_devices = None
+     torchrl_logger.info(
+         f"Creating inference model with num_devices={num_devices}, devices={vllm_devices}"
+     )
+
+     model_name = cfg.model.name
+
+     if tokenizer is None:
+         tokenizer = get_tokenizer(cfg)
+
+     # vLLM handles device mapping internally
+     inference_server = make_vllm_worker(
+         model_name=model_name,
+         gpu_memory_utilization=cfg.inference_model.gpu_memory_utilization,
+         num_devices=num_devices,
+         devices=list(vllm_devices)
+         if vllm_devices is not None
+         else None,  # Convert to list for type compatibility
+         make_ray_worker=make_ray_worker,
+         enforce_eager=cfg.inference_model.enforce_eager,
+     )
+     assert inference_server is not None
+     policy = vLLMWrapper(
+         inference_server,
+         input_mode="history",
+         chat_template_name="qwen",
+         return_log_probs=True,
+         tokenizer=tokenizer,
+         pad_output=False,
+         generate_kwargs={
+             "max_tokens": cfg.inference_model.max_tokens,
+             "include_stop_str_in_output": cfg.inference_model.include_stop_str_in_output,
+             "temperature": cfg.inference_model.temperature,
+         },
+     )
+     assert policy.model is not None
+     return policy
+
+
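+ # Example (illustrative sketch; requires vLLM, and a running Ray cluster when
+ # ``make_ray_worker=True``):
+ #     >>> policy = get_inference_model(cfg, make_ray_worker=False)
+ #     >>> rollout = policy(env.reset())  # generation handled by the vLLM engine
+
+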
+ def get_ref_model(
+     cfg: DictConfig, tokenizer: PreTrainedTokenizer, devices: list[int] | None = None
+ ) -> TransformersWrapper:
+     """Creates the reference model for KL penalty computation.
+
+     This function initializes a frozen copy of the base model to serve as the
+     reference model for KL divergence computation. The reference model is typically
+     quantized and does not require gradient computation.
+
+     Args:
+         cfg (DictConfig): The hydra configuration object containing model settings.
+             Expected to have a ref_model section with quantization and attention settings.
+         tokenizer (PreTrainedTokenizer): The tokenizer to use with the reference model.
+         devices (list[int] | None, optional): The devices to use for the reference
+             model. Default: ``None``.
+
+     Returns:
+         TransformersWrapper: The wrapped reference model in eval mode with detached weights.
+     """
+     torchrl_logger.info("Creating ref model")
+
+     # Get configured devices or default to [2]
+     if cfg.ref_model.num_devices is None:
+         ref_devices = devices if devices is not None else [2]
+     else:
+         ref_devices = list(range(cfg.ref_model.num_devices))
+
+     # Create max_memory dict - set 0 memory for GPUs we don't want to use
+     max_memory = {}
+     for i in range(torch.cuda.device_count()):
+         if i in ref_devices:
+             max_memory[i] = "24GiB"  # Allow max memory for devices we want to use
+         else:
+             max_memory[i] = "0GiB"  # No memory for other devices
+     max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback
+
+     # Let HF handle distribution with max_memory
+     device_map = "balanced" if len(ref_devices) > 1 else f"cuda:{ref_devices[0]}"
+     model_name = cfg.model.name
+
+     ref_model = get_hf_model(
+         model_name,
+         device_map=device_map,
+         max_memory=max_memory,
+         torch_dtype=getattr(torch, cfg.ref_model.torch_dtype),
+         quantize=cfg.ref_model.quantization.enabled,
+         gradient_checkpointing=cfg.ref_model.gradient_checkpointing,
+         attn_implementation=cfg.ref_model.attn_implementation,
+         lora=False,  # Reference model doesn't need LoRA
+         requires_grad=False,
+         eval_mode=True,
+         lora_dropout=0.0,
+     )[0]
+     # Detach the weights by replacing module parameters with plain tensors
+     TensorDict.from_module(ref_model).data.to_module(ref_model)
+     ref_model = TransformersWrapper(
+         ref_model,
+         tokenizer=tokenizer,
+         input_mode="history",
+         generate=False,
+         return_log_probs=True,
+         pad_output=False,
+         device=torch.device("cuda:0"),
+         chat_template_name="qwen",
+     )
+     return ref_model
+
+
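+ # Note on the detach idiom in get_ref_model above:
+ # ``TensorDict.from_module(m).data.to_module(m)`` reads the parameters of ``m``,
+ # takes their gradient-free ``.data`` view, and writes the plain tensors back,
+ # so the reference model carries no autograd state. Minimal sketch on a toy module:
+ #     >>> import torch.nn as nn
+ #     >>> m = nn.Linear(2, 2)
+ #     >>> TensorDict.from_module(m).data.to_module(m)  # weights no longer track grads
+
+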
+ def get_hf_model(
+     model_name: str,
+     torch_dtype: torch_dtype = torch.float32,
+     lora_r: int = 8,
+     lora_alpha: int = 16,
+     lora_dropout: float = 0.1,
+     quantize: bool = False,
+     fsdp: str = "",
+     fsdp_config: Any = None,
+     gradient_checkpointing: bool = True,
+     device_map: str
+     | dict[str, int | str | torch_device]
+     | int
+     | torch_device
+     | None = None,
+     lora: bool = True,
+     attn_implementation: Literal["flash_attention_2", "flex_attention", "sdpa"]
+     | None = "flex_attention",
+     requires_grad: bool = True,
+     compile: bool = False,
+     max_memory: dict[str, str] | None = None,
+     eval_mode: bool = False,
+ ) -> tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+     """Creates and configures a HuggingFace model with optional optimizations.
+
+     Args:
+         model_name (str): HuggingFace model identifier (e.g., "Qwen/Qwen2.5-3B").
+         torch_dtype (torch.dtype, optional): Model precision. Default: torch.float32
+         lora_r (int, optional): LoRA rank - controls capacity of adaptations. Default: 8
+         lora_alpha (int, optional): LoRA alpha - scales the adaptations. Default: 16
+         lora_dropout (float, optional): Dropout probability for LoRA layers. Default: 0.1
+         quantize (bool, optional): Whether to enable 4-bit quantization. Default: False
+         fsdp (str, optional): Fully Sharded Data Parallel configuration. Default: ""
+         fsdp_config (Any, optional): Additional FSDP configurations. Default: None
+         gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Default: True
+         device_map (str | dict | int | torch.device | None, optional): Device placement strategy. Default: None
+         lora (bool, optional): Whether to apply LoRA adapters. Default: True
+         attn_implementation (Literal["flash_attention_2", "flex_attention", "sdpa"] | None, optional):
+             Attention implementation to use. Default: "flex_attention"
+         requires_grad (bool, optional): Whether to enable gradient computation. Default: True
+         compile (bool, optional): Whether to enable model compilation. Default: False
+         max_memory (dict[str, str] | None, optional): Memory configuration for distributed
+             training. Default: None (treated as an empty dict)
+         eval_mode (bool, optional): Whether to use the model in eval mode. Default: False
+
+     Returns:
+         tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+             - model: The configured HuggingFace model
+             - tokenizer: The associated tokenizer
+
+     Raises:
+         ImportError: If required dependencies are not installed.
+         RuntimeError: If model initialization fails.
+     """
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     if max_memory is None:
+         max_memory = {}
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # tokenizer.eos_token = "<|im_end|>"
+     if tokenizer.pad_token == tokenizer.eos_token:
+         tokenizer.pad_token = "PAD"
+     tokenizer.padding_side = "left"
+
+     # Configure model settings for mixed precision.
+     # Store the original default dtype to restore it later.
+     original_dtype = torch.get_default_dtype()
+     torch.set_default_dtype(torch_dtype)
+
+     model_configs = {
+         "torch_dtype": torch_dtype,
+         "device_map": device_map if device_map is not None else "auto",
+         "max_memory": max_memory,
+     }
+     if torch.cuda.is_available() and attn_implementation:
+         torchrl_logger.info(f"{attn_implementation} init")
+         model_configs["attn_implementation"] = attn_implementation
+
+     try:
+         # Configure training settings based on FSDP usage
+         if fsdp != "" and fsdp_config is not None:
+             torchrl_logger.info("Configurations for FSDP")
+
+         # Enable quantization
+         if quantize:
+             try:
+                 from transformers.utils.quantization_config import BitsAndBytesConfig
+             except ImportError:
+                 raise ImportError(
+                     "Please install transformers with bitsandbytes support"
+                 )
+
+             bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch_dtype,
+             )
+             model_configs["quantization_config"] = bnb_config
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             use_cache=not gradient_checkpointing,
+             cache_dir="/tmp/.cache",
+             **model_configs,
+         )
+
+         # Configure gradient checkpointing based on FSDP usage
+         if fsdp == "" and fsdp_config is None:
+             if gradient_checkpointing:
+                 torchrl_logger.info("gradient_checkpointing enabled")
+                 model.gradient_checkpointing_enable()
+         else:
+             if gradient_checkpointing:
+                 torchrl_logger.info("gradient_checkpointing enabled")
+                 model.gradient_checkpointing_enable(
+                     gradient_checkpointing_kwargs={"use_reentrant": False}
+                 )
+
+         if lora:
+             try:
+                 from peft import get_peft_model, LoraConfig
+             except ImportError:
+                 raise ImportError("Please install peft: pip install peft")
+
+             # Create LoRA config with explicit dtype setting
+             lora_config = LoraConfig(
+                 r=lora_r,
+                 lora_alpha=lora_alpha,
+                 target_modules="all-linear",
+                 lora_dropout=lora_dropout,  # Standard dropout for regularization
+                 bias="none",
+                 task_type="CAUSAL_LM",
+                 inference_mode=eval_mode,  # CRITICAL: must be False for training
+                 init_lora_weights=True,  # Good practice
+             )
+
+             # Initialize LoRA model
+             model = get_peft_model(
+                 model,
+                 lora_config,
+                 autocast_adapter_dtype=False,  # Prevent automatic casting of adapter layers
+             )
+
+             # Force LoRA layers to the correct dtype
+             for n, p in model.named_parameters():
+                 if "lora_" in n:  # Only convert LoRA parameters
+                     p.data = p.data.to(torch_dtype)
+         if eval_mode:
+             model.eval()  # Ensure model is in eval mode
+         else:
+             model.train(True)
+         if requires_grad:
+             model.requires_grad_(True)
+         else:
+             model.requires_grad_(False)
+         return model, tokenizer
+
+     finally:
+         # Restore the original default dtype
+         torch.set_default_dtype(original_dtype)
+
+
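+ # Example (illustrative sketch; downloads the checkpoint on first use):
+ #     >>> model, tok = get_hf_model(
+ #     ...     "Qwen/Qwen2.5-3B",
+ #     ...     torch_dtype=torch.bfloat16,
+ #     ...     lora=True,
+ #     ...     device_map="cuda:0",
+ #     ... )
+
+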
+ def make_weight_sync_scheme(
+     master_address=None,
+     master_port=None,
+     vllm_tp_size=1,
+ ) -> VLLMWeightSyncScheme:
+     """Creates a vLLM weight synchronization scheme using NCCL collectives.
+
+     This function creates a weight sync scheme that uses NCCL for high-performance
+     GPU-to-GPU weight transfers from the training model to vLLM inference workers.
+
+     Args:
+         master_address (str | None): Address of the master node for distributed init.
+             Defaults to "localhost".
+         master_port (int | None): Port of the master node for distributed init.
+             If None, a port is auto-assigned.
+         vllm_tp_size (int): vLLM tensor parallel size (gpus_per_replica). Defaults to 1.
+
+     Returns:
+         VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
+     """
+     if master_address is None:
+         master_address = "localhost"
+
+     torchrl_logger.info(
+         f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, "
+         f"master_address={master_address}, master_port={master_port}"
+     )
+
+     return VLLMWeightSyncScheme(
+         master_address=master_address,
+         master_port=master_port,
+         gpus_per_replica=vllm_tp_size,
+         num_replicas=1,  # For expert iteration, typically 1 replica
+         strategy="state_dict",
+     )
+
+
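+ # Example (illustrative sketch; how the scheme is wired into a collector is an
+ # assumption about the surrounding training scripts, not a fixed API):
+ #     >>> scheme = make_weight_sync_scheme(
+ #     ...     master_port=29500, vllm_tp_size=cfg.inference_model.num_devices
+ #     ... )
+ #     >>> # e.g. handed to the collector so NCCL pushes trainer weights to vLLM
+
+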
+ def compute_device_allocation(cfg):
+     """Compute device allocation for different model components.
+
+     Args:
+         cfg: The configuration object.
+
+     Returns:
+         A dictionary containing device allocations for the different components.
+     """
+     train_devices = cfg.train_model.num_devices
+     inf_devices = cfg.inference_model.num_devices
+     ref_devices = cfg.ref_model.num_devices
+
+     # Ray needs to see all GPUs: training shares the leading devices with
+     # inference, and the reference model sits after the inference devices.
+     train_start = 0
+     train_end = train_devices
+     inference_start = 0
+     inference_end = inf_devices
+     ref_start = inference_end
+     ref_end = ref_start + ref_devices
+     ray_num_gpus = train_devices + inf_devices + ref_devices
+
+     # Create device lists
+     train_model_devices = list(range(train_start, train_end))
+     inference_model_devices = list(range(inference_start, inference_end))
+     ref_model_devices = list(range(ref_start, ref_end))
+
+     # Get total unique devices for CUDA_VISIBLE_DEVICES
+     all_devices = sorted(
+         set(train_model_devices + inference_model_devices + ref_model_devices)
+     )
+     cuda_visible_devices = ",".join(map(str, all_devices))
+
+     return {
+         "train_model_devices": train_model_devices,
+         "inference_model_devices": inference_model_devices,
+         "ref_model_devices": ref_model_devices,
+         "ray_num_gpus": ray_num_gpus,
+         "cuda_visible_devices": cuda_visible_devices,
+     }
+
+
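+ # Worked example (follows directly from the arithmetic above): with
+ # train_model.num_devices=2, inference_model.num_devices=1 and
+ # ref_model.num_devices=1, training spans GPUs [0, 1], inference shares GPU 0,
+ # and the reference model lands on GPU 1:
+ #     >>> alloc = compute_device_allocation(cfg)
+ #     >>> alloc["train_model_devices"], alloc["ref_model_devices"]
+ #     ([0, 1], [1])
+ #     >>> alloc["ray_num_gpus"], alloc["cuda_visible_devices"]
+ #     (4, '0,1')
+
+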
+ def create_cosine_scheduler_with_warmup(
+     optimizer: torch.optim.Optimizer,
+     num_warmup_steps: int,
+     num_training_steps: int,
+     num_cycles: float = 0.5,
+ ) -> torch.optim.lr_scheduler.LRScheduler:
+     """Create a cosine scheduler with warmup using PyTorch's built-in schedulers.
+
+     This function creates a learning rate scheduler that:
+     1. Linearly increases the learning rate from 1% of the base learning rate to
+        the base learning rate during warmup
+     2. Follows a cosine curve from the base learning rate down to 0 after warmup
+
+     Args:
+         optimizer: The optimizer to schedule learning rates for.
+         num_warmup_steps: Number of warmup steps.
+         num_training_steps: Total number of training steps.
+         num_cycles: Number of cosine cycles (default: 0.5 for half a cycle).
+             Currently unused: the CosineAnnealingLR stage always runs half a cycle.
+
+     Returns:
+         A PyTorch learning rate scheduler.
+     """
+     from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
+
+     # Create warmup scheduler (linear increase from 1% of the base LR to the base LR)
+     warmup_scheduler = LinearLR(
+         optimizer, start_factor=0.01, end_factor=1.0, total_iters=num_warmup_steps
+     )
+
+     # Create cosine decay scheduler (from base LR to 0)
+     cosine_scheduler = CosineAnnealingLR(
+         optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0.0
+     )
+
+     # Combine warmup and cosine decay
+     scheduler = SequentialLR(
+         optimizer,
+         schedulers=[warmup_scheduler, cosine_scheduler],
+         milestones=[num_warmup_steps],
+     )
+
+     return scheduler
+
+
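+ # Example (illustrative sketch): a 1000-step run with 100 warmup steps; the LR
+ # climbs from 1e-7 (1% of 1e-5) to 1e-5, then decays to 0 along a half cosine.
+ #     >>> opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-5)
+ #     >>> sched = create_cosine_scheduler_with_warmup(opt, 100, 1000)
+ #     >>> for _ in range(1000):
+ #     ...     opt.step()
+ #     ...     sched.step()
+
+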
+ def get_wandb_run_id(wandb_logger):
+     """Get the wandb run ID from a WandbLogger instance.
+
+     Args:
+         wandb_logger: The WandbLogger instance.
+
+     Returns:
+         str: The wandb run ID, or None if not available.
+     """
+     try:
+         # Wait a bit for wandb to initialize
+         max_attempts = 10
+         for attempt in range(max_attempts):
+             if hasattr(wandb_logger, "experiment") and wandb_logger.experiment:
+                 run_id = wandb_logger.experiment.id
+                 if run_id:
+                     torchrl_logger.info(f"Got wandb run ID: {run_id}")
+                     return run_id
+             if attempt < max_attempts - 1:
+                 time.sleep(0.5)
+                 torchrl_logger.info(
+                     f"Waiting for wandb run ID, attempt {attempt + 1}/{max_attempts}"
+                 )
+
+         torchrl_logger.warning("Could not get wandb run ID after multiple attempts")
+         return None
+     except Exception as e:
+         torchrl_logger.error(f"Error getting wandb run ID: {e}")
+         return None
+
+
+ def log_training_metrics(
+     wandb_logger,
+     replay_buffer,
+     batch,
+     loss,
+     grad_norm,
+     global_step,
+     data_read_count,
+     collector,
+     start_time,
+     gradient_accumulation_steps,
+     history_str=None,
+ ):
+     """Log training metrics to wandb.
+
+     Args:
+         wandb_logger: The wandb logger instance.
+         replay_buffer: The replay buffer containing collected data.
+         batch: The current training batch.
+         loss: The computed loss object.
+         grad_norm: The gradient norm value.
+         global_step: Current global training step.
+         data_read_count: Total data read count.
+         collector: The collector instance.
+         start_time: Training start time.
+         gradient_accumulation_steps: Number of gradient accumulation steps.
+         history_str: Optional history string for logging.
+     """
+     with torch.no_grad():
+         rb_content = replay_buffer[:]
+         batch_policy_version = batch["next", "policy_version"].view(-1).min()
+         batch_policy_age = collector.policy_version - batch_policy_version
+
+         metrics = {
+             "reward from buffer": float(
+                 torch.cat(rb_content.get(("next", "reward"), as_list=True)).mean()
+             ),
+             "reward from batch": float(batch["next", "reward"].mean()),
+             "seq_length from buffer": float(
+                 torch.tensor(
+                     [
+                         t.numel()
+                         for t in rb_content.get(("tokens", "response"), as_list=True)
+                     ],
+                     dtype=torch.float,
+                 ).mean()
+             ),
+             "loss_sft, from loss": float(loss.loss_sft),
+             "loss_kl_to_ref, from loss": float(loss.loss_kl_to_ref),
+             "kl_to_ref, from loss": float(loss.kl_to_ref),
+             "grad_norm": float(grad_norm)
+             if global_step % gradient_accumulation_steps == 0
+             else 0.0,
+             "write_count, from buffer": int(replay_buffer.write_count),
+             # how many gradient steps per write
+             "gradient_step_throughput (gradient step per write)": float(
+                 global_step / replay_buffer.write_count
+             ),
+             # how many optim steps per write
+             "optim_step_throughput (optim step per write)": float(
+                 (global_step // gradient_accumulation_steps) / replay_buffer.write_count
+             ),
+             "data_read_count (total)": data_read_count,
+             "current_policy_version (collector)": collector.policy_version,
+             # FIXME: Assume batch is a single trajectory
+             # FIXME: The addition of the transform after the env instantiation + _shuttle creation
+             #  is messed up - we need the next data
+             "batch_policy_version (sampled batch)": batch_policy_version,
+             "batch_policy_age (sampled batch)": batch_policy_age,
+             "throughput (steps per second)": float(
+                 global_step / (time.time() - start_time)
+             ),
+         }
+
+         for name, value in metrics.items():
+             wandb_logger.log_scalar(name, value, step=global_step)
+
+         if history_str is not None:
+             wandb_logger.log_str("history", history_str, step=global_step)
+
+
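+ # Example call site (illustrative sketch; the local names mirror the training
+ # loop and ``cfg.train.gradient_accumulation_steps`` is an assumed config key):
+ #     >>> log_training_metrics(
+ #     ...     wandb_logger, replay_buffer, batch, loss, grad_norm,
+ #     ...     global_step, data_read_count, collector, start_time,
+ #     ...     gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+ #     ... )
+
+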
+ class RemoteDataLogger:
+     """A remote post-processing hook that sends logging data to the main process via Ray for centralized logging."""
+
+     def __init__(self, log_queue):
+         """Initialize RemoteDataLogger with a Ray queue for logging.
+
+         Args:
+             log_queue: Ray queue for logging data.
+         """
+         self.log_queue = log_queue
+         self.last_time = None
+
+     def __call__(self, data: TensorDict):
+         self.log_data(data)
+         return data
+
+     def log_data(self, data: TensorDict):
+         logs = {}
+         if self.last_time is None:
+             self.last_time = time.time()
+         else:
+             t = time.time()
+             elapsed = t - self.last_time
+             logs["collector/time/elapsed"] = elapsed
+             self.last_time = t
+
+         # Prepare logging data
+         logs["collector/rewards/mean"] = float(data["next", "reward"].mean())
+         logs["collector/rewards/std"] = float(data["next", "reward"].std())
+         logs["collector/rewards/min"] = float(data["next", "reward"].min())
+         logs["collector/rewards/max"] = float(data["next", "reward"].max())
+
+         # Response length
+         lengths = []
+         responses = data["text", "response"]
+         for r in responses:
+             lengths.append(len(r))
+         lengths = torch.tensor(lengths, dtype=torch.float32)
+         logs["collector/response_length/mean"] = float(lengths.mean())
+         logs["collector/response_length/std"] = float(lengths.std())
+         logs["collector/response_length/min"] = float(lengths.min())
+         logs["collector/response_length/max"] = float(lengths.max())
+
+         policy_versions = data.get(("next", "policy_version"))
+         if isinstance(policy_versions, torch.Tensor):
+             policy_versions = policy_versions.float()
+             logs["collector/policy_version/mean"] = float(policy_versions.mean())
+             logs["collector/policy_version/min"] = float(policy_versions.min())
+             logs["collector/policy_version/max"] = float(policy_versions.max())
+
+         # Send to the main process via the Ray queue
+         try:
+             self.log_queue.put(logs)
+         except Exception as e:
+             torchrl_logger.error(f"Failed to send logs to main process: {e}")
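+
+
+ # Example (illustrative sketch; any object with a ``put`` method works as the
+ # queue - ``ray.util.queue.Queue`` is one option):
+ #     >>> from ray.util.queue import Queue
+ #     >>> log_queue = Queue()
+ #     >>> postproc = RemoteDataLogger(log_queue)
+ #     >>> data = postproc(data)  # called in the collector; main process drains log_queue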