torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,311 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import time
9
+
10
+ from collections.abc import Iterator
11
+
12
+ import torch
13
+ from tensordict import TensorDictBase
14
+ from torchrl._utils import logger as torchrl_logger
15
+ from torchrl.collectors.weight_update import WeightUpdaterBase
16
+ from torchrl.modules.llm.backends.vllm import RLvLLMEngine
17
+
18
import importlib.util

# ``transformers`` is only required by the ``push_weights_from_transformers*``
# helpers of the updater below. Probe for the package without importing it;
# a bare ``try: pass`` guard imports nothing and would always report True.
_has_transformers = importlib.util.find_spec("transformers") is not None
24
+
25
+
26
class vLLMUpdaterV2(WeightUpdaterBase):
    """Simplified vLLM weight updater using the RLvLLMEngine interface.

    This updater works with any vLLM engine that implements the RLvLLMEngine
    interface, automatically extracting configuration and handling weight updates
    through the engine's own methods.

    Each completed weight push (:meth:`push_weights`,
    :meth:`push_weights_from_transformers` or
    :meth:`push_weights_from_transformers_optimized`) runs the registered
    post-hooks exactly once. After a collector has been registered, one of those
    hooks increments the policy version of every registered collector.

    Args:
        vllm_engine: A vLLM engine implementing the RLvLLMEngine interface.

    Raises:
        TypeError: if ``vllm_engine`` does not implement ``RLvLLMEngine``.

    .. note:: This class can be created through :class:`torchrl.collectors.llm.vLLMUpdater` with `v2=True`.

    """

    def __init__(self, vllm_engine: RLvLLMEngine):
        # Check that vllm_engine implements the RLvLLMEngine interface
        if not isinstance(vllm_engine, RLvLLMEngine):
            raise TypeError(
                f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}"
            )

        torchrl_logger.debug(f"=> in {type(self).__name__}.__init__")
        self.vllm_engine = vllm_engine
        # ``None`` until ``init()`` has set up the engine's weight update group.
        self.initialized_group = None
        # Whether the collector version-increment post-hook has been registered.
        self._version_hook_registered = False

        # Extract configuration from engine
        self.vllm_tp_size = vllm_engine.get_tp_size()
        self.master_address = vllm_engine.get_master_address()
        self.master_port = vllm_engine.get_master_port()
        self.model_metadata = vllm_engine.get_model_metadata()

        torchrl_logger.debug(
            f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}"
        )

    def get_tp_size(self) -> int:
        """Get the tensor parallel size."""
        return self.vllm_tp_size

    def init(
        self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None
    ) -> None:
        """Initialize the weight updater.

        Args:
            model_metadata: Optional model metadata. If not provided, uses engine's metadata.
        """
        if model_metadata is not None:
            self.model_metadata = model_metadata

        # Initialize the engine's weight update group
        self.vllm_engine.init_weight_update_group()
        self.initialized_group = True
        torchrl_logger.debug("Weight update group initialized")

    def _check_initialized(self) -> None:
        """Raise ``RuntimeError`` if :meth:`init` has not been called yet."""
        if self.initialized_group is None:
            raise RuntimeError("Weight updater not initialized. Call init() first.")

    def _finalize_update(self) -> None:
        """Run the post-hooks once a full weight update has been sent."""
        # Call post-hooks to increment policy version
        torchrl_logger.debug("Calling post-hooks...")
        self._call_post_hooks()
        torchrl_logger.debug("Post-hooks completed")

    def push_weights(
        self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase
    ):
        """Push weights to the vLLM engine.

        Args:
            weights: Either an iterator of (name, tensor) pairs or a TensorDictBase.
                Nested TensorDict keys are flattened into dot-separated names.

        Raises:
            RuntimeError: if :meth:`init` has not been called.
        """
        self._check_initialized()

        if isinstance(weights, TensorDictBase):
            weights = iter(weights.flatten_keys(".").items())

        # Delegate to the engine's update_weights method
        self.vllm_engine.update_weights(weights)
        torchrl_logger.debug("Weight update completed")
        self._finalize_update()

    @staticmethod
    def _extract_state_dict(model) -> dict[str, torch.Tensor]:
        """Return the state dict of ``model``, or of a TorchRL wrapper's ``.model``.

        Objects exposing ``merge_and_unload`` (presumably PEFT/LoRA models) are
        merged first, and the state dict of the merged model is returned.

        Raises:
            TypeError: if no ``state_dict`` can be found on ``model``.
        """
        if hasattr(model, "model") and hasattr(model.model, "state_dict"):
            # TorchRL wrapper (e.g., TransformersWrapper)
            model = model.model
        elif not hasattr(model, "state_dict"):
            raise TypeError(f"Cannot extract state_dict from {type(model)}")
        if hasattr(model, "merge_and_unload"):
            # NOTE(review): for PEFT models ``merge_and_unload`` is believed to
            # merge adapters into the base model in place and drop the LoRA
            # layers - confirm this is acceptable if training continues.
            model = model.merge_and_unload()
        return model.state_dict()

    def push_weights_from_transformers(self, transformers_model):
        """Push weights from a transformers model.

        Args:
            transformers_model: A transformers PreTrainedModel or TorchRL wrapper

        Raises:
            ImportError: if transformers is not available.
            RuntimeError: if :meth:`init` has not been called.
            TypeError: if no state dict can be extracted from the model.
        """
        if not _has_transformers:
            raise ImportError("transformers not available")
        # Fail before (possibly mutating) state-dict extraction.
        self._check_initialized()

        t0 = time.perf_counter()
        state_dict = self._extract_state_dict(transformers_model)
        t1 = time.perf_counter()
        torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0}")
        # Convert to iterator for memory efficiency
        self.push_weights(iter(state_dict.items()))
        torchrl_logger.debug(f"Time to push weights: {time.perf_counter() - t1}")

    def push_weights_from_transformers_optimized(
        self, transformers_model, batch_size=50
    ):
        """Optimized version of push_weights_from_transformers with GPU pre-loading.

        This method provides several optimizations:
        1. Pre-loads all weights to GPU before transfer
        2. Optionally batches weights for better memory management
        3. Uses non-blocking transfers when possible

        All batches together form a single weight update. The post-hooks
        (including the collector version increment installed by
        :meth:`register_collector`) therefore run once, after the last batch.

        Args:
            transformers_model: A transformers PreTrainedModel or TorchRL wrapper
            batch_size: Number of weights to transfer in each batch
                (0 or negative = no batching)

        Raises:
            ImportError: if transformers is not available.
            RuntimeError: if :meth:`init` has not been called.
            TypeError: if no state dict can be extracted from the model.
        """
        if not _has_transformers:
            raise ImportError("transformers not available")
        # Fail before extracting and copying weights to the GPU.
        self._check_initialized()

        t0 = time.perf_counter()
        state_dict = self._extract_state_dict(transformers_model)
        t1 = time.perf_counter()
        torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0:.3f}s")

        # Pre-load all weights to GPU for faster transfer. ``Tensor.cuda()``
        # with no argument targets the current CUDA device.
        gpu_weights = {
            name: weight if weight.is_cuda else weight.cuda(non_blocking=True)
            for name, weight in state_dict.items()
        }
        # Synchronize to ensure all (possibly asynchronous) copies are complete
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        torchrl_logger.debug(f"Time to move weights to GPU: {t2 - t1:.3f}s")

        # Transfer weights (optionally in batches) straight to the engine, so
        # that post-hooks run once for the whole update rather than per batch.
        weight_items = list(gpu_weights.items())
        if batch_size > 0:
            num_batches = (len(weight_items) + batch_size - 1) // batch_size
            for start in range(0, len(weight_items), batch_size):
                batch = weight_items[start : start + batch_size]
                self.vllm_engine.update_weights(iter(batch))
                torchrl_logger.debug(
                    f"Transferred batch {start // batch_size + 1}/{num_batches}"
                )
        else:
            # Transfer all at once
            self.vllm_engine.update_weights(iter(weight_items))
        torchrl_logger.debug("Weight update completed")
        self._finalize_update()

        t3 = time.perf_counter()
        torchrl_logger.debug(
            f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s"
        )

    # Required WeightUpdaterBase methods
    def _sync_weights_with_worker(self, *, worker_id=None, server_weights=None):
        """Sync weights with worker (delegates to push_weights).

        Raises:
            ValueError: if ``server_weights`` is None.
        """
        if server_weights is None:
            raise ValueError("server_weights cannot be None")

        if isinstance(server_weights, TensorDictBase):
            # Let push_weights flatten nested keys; ``.items()`` would only
            # yield top-level entries (and nested sub-TensorDicts as values).
            self.push_weights(server_weights)
        elif hasattr(server_weights, "items"):
            # Dict-like object
            self.push_weights(iter(server_weights.items()))
        else:
            # Assume it's a model with state_dict
            self.push_weights_from_transformers(server_weights)

    def _get_server_weights(self):
        """Not used - weights must be passed directly."""
        return None

    def _maybe_map_weights(self, server_weights):
        """Map weights to expected format."""
        return server_weights  # No mapping needed, handled in push_weights methods

    def all_worker_ids(self):
        """Return list of worker IDs."""
        return [0]

    def register_collector(self, collector):  # noqa: F821
        """Register a collector and set up policy version increment post-hook.

        Args:
            collector: The collector to register (BaseCollector)
        """
        result = super().register_collector(collector)

        # Register the increment_version post-hook exactly once. The hook
        # iterates over all collectors, so registering it per collector would
        # increment each version N times per update. A dedicated flag is used
        # so that unrelated, previously registered post-hooks do not suppress it.
        if not getattr(self, "_version_hook_registered", False):
            torchrl_logger.debug("Registering policy version increment post-hook")
            self.register_post_hook(self._increment_all_collector_versions)
            self._version_hook_registered = True

        return result

    def _increment_all_collector_versions(self):
        """Increment version for all registered collectors efficiently.

        Failures are logged as warnings rather than raised (best effort).
        """
        collectors = self.collectors
        num_collectors = len(collectors)
        torchrl_logger.debug(
            f"Incrementing policy version for {num_collectors} collectors..."
        )
        for i, collector in enumerate(collectors):
            try:
                collector.increment_version()
                torchrl_logger.debug(
                    f"Incremented version for collector {i + 1}/{num_collectors}"
                )
            except Exception as e:
                torchrl_logger.warning(
                    f"Failed to increment version for collector {i + 1}: {e}"
                )
        torchrl_logger.debug("All collector versions incremented")

    @classmethod
    def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Get model metadata from a model.

        Args:
            model: A model with state_dict() method (e.g., TransformersWrapper)

        Returns:
            dict: Mapping of parameter names to (dtype, shape) tuples

        Raises:
            TypeError: if no state dict can be extracted from ``model``.
        """
        state_dict = cls._extract_state_dict(model)
        return {k: (v.dtype, v.shape) for k, v in state_dict.items()}

    # Remove the weakrefs from the updater for serialization
    def __getstate__(self):
        state = self.__dict__.copy()
        state["_collector_wrs"] = None
        return state