torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the registry.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/testing/modules.py
@@ -0,0 +1,295 @@
+ from __future__ import annotations
+
+ import torch
+ from tensordict import NestedKey, TensorDictBase
+ from tensordict.nn import TensorDictModuleBase
+ from torch import nn, vmap
+
+ from torchrl._utils import logger, RL_WARNINGS
+ from torchrl.modules import MLP
+ from torchrl.objectives.value.advantages import _vmap_func
+
+ __all__ = [
+     "BiasModule",
+     "LSTMNet",
+     "NonSerializableBiasModule",
+     "call_value_nets",
+ ]
+
+
+ class BiasModule(nn.Module):
+     """Simple bias module to check weight synchronization correctness."""
+
+     def __init__(self, value: float = 0.0):
+         super().__init__()
+         self.bias = nn.Parameter(torch.tensor(value, dtype=torch.float))
+
+     def forward(self, x):
+         return x + self.bias
+
+
+ class NonSerializableBiasModule(BiasModule):
+     """Bias module that intentionally fails to serialize.
+
+     This is used in tests to simulate a policy that cannot be pickled.
+     """
+
+     def __getstate__(self):
+         # Simulate a non-serializable policy by raising on pickling
+         raise RuntimeError("NonSerializableBiasModule cannot be pickled")
+
+
+ class LSTMNet(nn.Module):
+     """An embedder for an LSTM preceded by an MLP.
+
+     The forward method returns both the output and the hidden states of the
+     current step (the input hidden states), mirroring how the environment
+     returns the 'observation' and 'next_observation'.
+
+     Because the LSTM kernel only returns the last hidden state, hidden states
+     are padded with zeros such that they have the right size to be stored in a
+     TensorDict of size [batch x time_steps].
+
+     If a 2D tensor is provided as input, it is assumed to be a batch of data
+     with only one time step. This means that we explicitly assume that users
+     will unsqueeze inputs of a single batch with multiple time steps.
+
+     Args:
+         out_features (int): number of output features.
+         lstm_kwargs (dict): the keyword arguments for the
+             :class:`~torch.nn.LSTM` layer.
+         mlp_kwargs (dict): the keyword arguments for the
+             :class:`~torchrl.modules.MLP` layer.
+         device (torch.device, optional): the device where the module should
+             be instantiated.
+
+     Keyword Args:
+         lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that
+             indicates where the LSTM class is to be retrieved. The ``"torchrl"``
+             backend (:class:`~torchrl.modules.LSTM`) is slower but works with
+             :func:`~torch.vmap` and should work with :func:`~torch.compile`.
+             Defaults to ``"torch"``.
+
+     Examples:
+         >>> batch = 7
+         >>> time_steps = 6
+         >>> in_features = 4
+         >>> out_features = 10
+         >>> hidden_size = 5
+         >>> net = LSTMNet(
+         ...     out_features,
+         ...     {"input_size": hidden_size, "hidden_size": hidden_size},
+         ...     {"out_features": hidden_size},
+         ... )
+         >>> # test single step vs multi-step
+         >>> x = torch.randn(batch, time_steps, in_features)  # 3 dims = multi-step
+         >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
+         >>> x = torch.randn(batch, in_features)  # 2 dims = single step
+         >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
+
+     """
+
+     def __init__(
+         self,
+         out_features: int,
+         lstm_kwargs,
+         mlp_kwargs,
+         device=None,
+         *,
+         lstm_backend: str | None = None,
+     ) -> None:
+         super().__init__()
+         lstm_kwargs.update({"batch_first": True})
+         self.mlp = MLP(device=device, **mlp_kwargs)
+         if lstm_backend is None:
+             lstm_backend = "torch"
+         self.lstm_backend = lstm_backend
+         if self.lstm_backend == "torch":
+             LSTM = nn.LSTM
+         else:
+             from torchrl.modules.tensordict_module.rnn import LSTM
+         self.lstm = LSTM(device=device, **lstm_kwargs)
+         self.linear = nn.LazyLinear(out_features, device=device)
+
+     def _lstm(
+         self,
+         input: torch.Tensor,
+         hidden0_in: torch.Tensor | None = None,
+         hidden1_in: torch.Tensor | None = None,
+     ):
+         squeeze0 = False
+         squeeze1 = False
+         if input.ndimension() == 1:
+             squeeze0 = True
+             input = input.unsqueeze(0).contiguous()
+
+         if input.ndimension() == 2:
+             squeeze1 = True
+             input = input.unsqueeze(1).contiguous()
+         batch, steps = input.shape[:2]
+
+         if hidden1_in is None and hidden0_in is None:
+             shape = (batch, steps) if not squeeze1 else (batch,)
+             hidden0_in, hidden1_in = (
+                 torch.zeros(
+                     *shape,
+                     self.lstm.num_layers,
+                     self.lstm.hidden_size,
+                     device=input.device,
+                     dtype=input.dtype,
+                 )
+                 for _ in range(2)
+             )
+         elif hidden1_in is None or hidden0_in is None:
+             raise RuntimeError(
+                 f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
+             )
+         elif squeeze0:
+             hidden0_in = hidden0_in.unsqueeze(0)
+             hidden1_in = hidden1_in.unsqueeze(0)
+
+         # we only need the first hidden state
+         if not squeeze1:
+             _hidden0_in = hidden0_in[:, 0]
+             _hidden1_in = hidden1_in[:, 0]
+         else:
+             _hidden0_in = hidden0_in
+             _hidden1_in = hidden1_in
+         hidden = (
+             _hidden0_in.transpose(-3, -2).contiguous(),
+             _hidden1_in.transpose(-3, -2).contiguous(),
+         )
+
+         y0, hidden = self.lstm(input, hidden)
+         # dim 0 in hidden is num_layers, but that will conflict with tensordict
+         hidden = tuple(_h.transpose(0, 1) for _h in hidden)
+         y = self.linear(y0)
+
+         out = [y, hidden0_in, hidden1_in, *hidden]
+         if squeeze1:
+             # squeezes time
+             out[0] = out[0].squeeze(1)
+         if not squeeze1:
+             # we pad the hidden states with zero to make tensordict happy
+             for i in range(3, 5):
+                 out[i] = torch.stack(
+                     [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)]
+                     + [out[i]],
+                     1,
+                 )
+         if squeeze0:
+             out = [_out.squeeze(0) for _out in out]
+         return tuple(out)
+
+     def forward(
+         self,
+         input: torch.Tensor,
+         hidden0_in: torch.Tensor | None = None,
+         hidden1_in: torch.Tensor | None = None,
+     ):
+         input = self.mlp(input)
+         return self._lstm(input, hidden0_in, hidden1_in)
+
+
+ def call_value_nets(
+     value_net: TensorDictModuleBase,
+     data: TensorDictBase,
+     params: TensorDictBase,
+     next_params: TensorDictBase,
+     single_call: bool,
+     value_key: NestedKey,
+     detach_next: bool,
+     vmap_randomness: str = "error",
+ ):
+     """Call value networks to compute values at t and t+1.
+
+     This is a testing utility for computing value estimates in advantage
+     calculations.
+
+     Args:
+         value_net: The value network module.
+         data: Input tensordict with observations.
+         params: Parameters for the value network at time t.
+         next_params: Parameters for the value network at time t+1.
+         single_call: Whether to use a single forward pass for both t and t+1.
+         value_key: The key where values are stored.
+         detach_next: Whether to detach the next value from the computation graph.
+         vmap_randomness: Randomness mode for vmap.
+
+     Returns:
+         Tuple of (value, value_next).
+     """
+     in_keys = value_net.in_keys
+     if single_call:
+         for i, name in enumerate(data.names):
+             if name == "time":
+                 ndim = i + 1
+                 break
+         else:
+             ndim = None
+         if ndim is not None:
+             # get data at t and last of t+1
+             idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),)
+             idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
+             idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
+             data_in = torch.cat(
+                 [
+                     data.select(*in_keys, value_key, strict=False),
+                     data.get("next").select(*in_keys, value_key, strict=False)[idx0],
+                 ],
+                 ndim - 1,
+             )
+         else:
+             if RL_WARNINGS:
+                 logger.warning(
+                     "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
+                     "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
+                 )
+             ndim = data.ndim
+             idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),)
+             idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),)
+             data_in = torch.cat(
+                 [
+                     data.select(*in_keys, value_key, strict=False),
+                     data.get("next").select(*in_keys, value_key, strict=False),
+                 ],
+                 ndim - 1,
+             )
+
+         # next_params should be None or be identical to params
+         if next_params is not None and next_params is not params:
+             raise ValueError(
+                 "the value at t and t+1 cannot be retrieved in a single call without resorting to vmap when both params and next_params are passed."
+             )
+         if params is not None:
+             with params.to_module(value_net):
+                 value_est = value_net(data_in).get(value_key)
+         else:
+             value_est = value_net(data_in).get(value_key)
+         value, value_ = value_est[idx], value_est[idx_]
+     else:
+         data_in = torch.stack(
+             [
+                 data.select(*in_keys, value_key, strict=False),
+                 data.get("next").select(*in_keys, value_key, strict=False),
+             ],
+             0,
+         )
+         if (params is not None) ^ (next_params is not None):
+             raise ValueError(
+                 "params and next_params must either both be provided or both be None."
+             )
+         elif params is not None:
+             params_stack = torch.stack([params, next_params], 0).contiguous()
+             data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
+                 data_in, params_stack
+             )
+         else:
+             data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
+         value_est = data_out.get(value_key)
+         value, value_ = value_est[0], value_est[1]
+     data.set(value_key, value)
+     data.set(("next", value_key), value_)
+     if detach_next:
+         value_ = value_.detach()
+     return value, value_
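
The `call_value_nets` docstring above describes the single-call path but ships no example. A minimal doctest-style sketch of how it might be driven, assuming a hypothetical `state_value` key and illustrative shapes (none of which come from the diff itself):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.testing.modules import call_value_nets
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"]
... )
>>> data = TensorDict(
...     {
...         "observation": torch.randn(2, 5, 3),
...         "next": {"observation": torch.randn(2, 5, 3)},
...     },
...     batch_size=[2, 5],
... ).refine_names(..., "time")  # mark the time dimension so single_call can locate it
>>> value, value_next = call_value_nets(
...     value_net,
...     data,
...     params=None,
...     next_params=None,
...     single_call=True,  # one forward pass over t plus the last t+1 step
...     value_key="state_value",
...     detach_next=True,
... )
>>> value.shape
torch.Size([2, 5, 1])

With `single_call=True` the function concatenates the current steps with the final "next" step along the time dimension, runs one forward pass, and slices the result back into values at t and t+1; both are also written into `data` under `value_key` and `("next", value_key)`.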
torchrl/testing/mp_helpers.py
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ import torch
+
+ from torchrl.data.utils import CloudpickleWrapper
+
+
+ def decorate_thread_sub_func(func, num_threads):
+     """Wrap ``func`` so it asserts that the current torch thread count equals ``num_threads`` before running."""
+
+     def new_func(*args, **kwargs):
+         assert torch.get_num_threads() == num_threads
+         return func(*args, **kwargs)
+
+     return CloudpickleWrapper(new_func)
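
This helper exists so collector tests can verify that a configured thread count actually propagates into worker processes; the `CloudpickleWrapper` makes the closure picklable across process boundaries. A quick in-process sketch of the contract (the wrapped function here is made up):

>>> import torch
>>> from torchrl.testing.mp_helpers import decorate_thread_sub_func
>>> torch.set_num_threads(2)
>>> wrapped = decorate_thread_sub_func(lambda x: x * 2, num_threads=2)
>>> wrapped(3)  # thread count matches, assertion passes
6
>>> torch.set_num_threads(4)
>>> # wrapped(3) would now raise AssertionError: the callee sees 4 threads, not 2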
torchrl/testing/ray_helpers.py
@@ -0,0 +1,293 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Helper classes for Ray-based weight synchronization tests.
+
+ This module contains Ray actor classes that need to be importable by Ray workers.
+ These classes are used in tests but must be defined at module level in a proper
+ Python package (not in test files) so Ray can serialize and import them on remote workers.
+ """
+ from __future__ import annotations
+
+ import torch
+ from torchrl._utils import logger
+
+
+ class WorkerVLLMNCCL:
+     """Ray actor for vLLM inference worker (receiver) using NCCL collective communication."""
+
+     def __init__(
+         self,
+         scheme_config: dict,
+         model_name: str = "Qwen/Qwen2.5-0.5B",
+         trainer_actor_name: str = "Trainer",
+     ):
+         # Store config for deferred initialization
+         self.scheme_config = scheme_config
+         self.model_name = model_name
+         self.trainer_actor_name = trainer_actor_name
+         self.wrapper = None
+         self.engine = None
+         self.receiver = None
+         self.scheme = None
+         self.trainer = None
+         self.model_metadata = None
+
+     def setup(self):
+         """Set up vLLM engine (deferred from __init__ to avoid blocking)."""
+         from torchrl.modules.llm.backends import AsyncVLLM
+         from torchrl.modules.llm.policies import vLLMWrapper
+
+         # Create vLLM wrapper
+         async_engine = AsyncVLLM.from_pretrained(
+             self.model_name,
+             num_replicas=2,  # Number of engine replicas
+         )
+         self.wrapper = vLLMWrapper(async_engine, input_mode="history")
+         self.engine = self.wrapper.model
+
+         # Create scheme from config
+         from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme
+
+         self.scheme = VLLMWeightSyncScheme(**self.scheme_config)
+
+         # Create receiver (engine handles rank assignment automatically)
+         self.receiver = self.scheme.create_receiver(self.engine)
+         return "setup_complete"
+
+     def init_metadata(self):
+         """Initialize the receiver by fetching metadata from trainer."""
+         import ray
+
+         if self.receiver is None:
+             raise RuntimeError("Must call setup() before init_metadata()")
+
+         # Get trainer actor by name
+         logger.info(f"Getting trainer actor by name {self.trainer_actor_name}")
+         self.trainer = ray.get_actor(self.trainer_actor_name)
+
+         # Fetch model metadata from trainer
+         logger.info("Fetching model metadata from trainer (requires max_concurrency>1)")
+         self.model_metadata = ray.get(self.trainer.get_model_metadata.remote())
+
+     def init(self):
+         if self.model_metadata is None:
+             raise RuntimeError("Must call init_metadata() before init()")
+
+         # Initialize receiver with metadata
+         logger.info("Initializing receiver...")
+         self.receiver.init_all_workers_group(self.model_metadata)
+         self.initialized = True
+         logger.info("Receiver initialized")
+         return "initialized"
+
+     def get_engine(self):
+         """Get the vLLM engine reference for RPC coordination."""
+         if self.engine is None:
+             raise RuntimeError("Must call setup() first")
+         return self.engine
+
+     def get_sample_output(self):
+         """Get a sample output to verify model works."""
+         # Simple inference test
+         return "vllm_ready"
+
+     @classmethod
+     def as_remote(cls, *args, **kwargs):
+         import ray
+
+         # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs)
+         # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer
+         return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls)
+
+
+ class WorkerTransformerNCCL:
+     """Ray actor for transformer trainer (sender) using NCCL collective communication."""
+
+     def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
+         from torchrl.weight_update.llm.vllm_nccl import (
+             get_model_metadata,
+             VLLMWeightSyncScheme,
+         )
+         from transformers import AutoModelForCausalLM
+
+         # Create transformer model
+         transformer = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             dtype=torch.float16,
+         )
+         self.transformer = transformer.cuda()
+
+         # Create scheme from config
+         self.scheme = VLLMWeightSyncScheme(**scheme_config)
+
+         # Create sender
+         self.sender = self.scheme.create_sender()
+         self.sender.register_model(self.transformer)
+
+         # Extract and store model metadata
+         self.model_metadata = get_model_metadata(self.transformer)
+
+     def init(self, vllm_engine=None):
+         """Initialize sender with optional vLLM engine for RPC coordination.
+
+         Args:
+             vllm_engine: Optional vLLM engine reference for calling collective_rpc
+         """
+         if self.model_metadata is None:
+             raise RuntimeError("Model metadata is missing; __init__ should have populated it")
+
+         self.sender.init_all_workers_group(self.model_metadata, vllm_engine=vllm_engine)
+         self.initialized = True
+         logger.info("Trainer initialized")
+         return "initialized"
+
+     def get_model_metadata(self):
+         """Get model metadata to share with receiver."""
+         return self.model_metadata
+
+     def update_weights(self, modify_weights: bool = False):
+         """Trigger a weight update broadcast.
+
+         Args:
+             modify_weights: If True, modifies weights before broadcasting
+                 for verification purposes.
+
+         Returns:
+             str: "updated" status message
+         """
+         # Optionally modify weights for testing
+         if modify_weights:
+             with torch.no_grad():
+                 first_param = next(self.transformer.parameters())
+                 first_param.add_(0.01)
+
+         # Broadcast weights to all vLLM workers
+         self.sender.update_weights()
+         return "updated"
+
+     def get_first_param_sum(self):
+         """Get sum of first parameter for verification."""
+         return next(self.transformer.parameters()).sum().item()
+
+     @classmethod
+     def as_remote(cls, *args, **kwargs):
+         import ray
+
+         return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls)
+
+
+ class WorkerVLLMDoubleBuffer:
+     """Ray actor for vLLM inference worker (receiver) using double-buffered storage."""
+
+     def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
+         # Store config for deferred initialization
+         self.scheme_config = scheme_config
+         self.model_name = model_name
+         self.wrapper = None
+         self.engine = None
+         self.receiver = None
+         self.scheme = None
+
+     def setup(self):
+         """Set up vLLM engine and receiver."""
+         from torchrl.modules.llm.backends import AsyncVLLM
+         from torchrl.modules.llm.policies import vLLMWrapper
+
+         # Create vLLM wrapper
+         async_engine = AsyncVLLM.from_pretrained(
+             self.model_name,
+             num_replicas=1,  # Single replica for simplicity
+         )
+         self.wrapper = vLLMWrapper(async_engine, input_mode="history")
+         self.engine = self.wrapper.model
+
+         # Create scheme from config
+         from torchrl.weight_update.llm.vllm_double_buffer import (
+             VLLMDoubleBufferSyncScheme,
+         )
+
+         self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config)
+
+         # Create receiver
+         self.receiver = self.scheme.create_receiver(self.engine)
+         logger.info("Receiver setup complete")
+         return "setup_complete"
+
+     def poll_and_apply_weights(self):
+         """Poll for new weights and apply them to the engine."""
+         if self.receiver is None:
+             raise RuntimeError("Must call setup() first")
+
+         success = self.receiver.poll_and_apply()
+         return success
+
+     def get_sample_output(self):
+         """Get a sample output to verify model works."""
+         return "vllm_ready"
+
+     @classmethod
+     def as_remote(cls, *args, **kwargs):
+         import ray
+
+         # vLLM worker needs 1 GPU
+         return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)
+
+
+ class WorkerTransformerDoubleBuffer:
+     """Ray actor for transformer trainer (sender) using double-buffered storage."""
+
+     def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
+         from torchrl.weight_update.llm.vllm_double_buffer import (
+             VLLMDoubleBufferSyncScheme,
+         )
+         from transformers import AutoModelForCausalLM
+
+         # Create transformer model
+         transformer = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             dtype=torch.float16,
+         )
+         self.transformer = transformer.cuda()
+
+         # Create scheme from config
+         self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config)
+
+         # Create sender
+         self.sender = self.scheme.create_sender()
+         self.sender.register_model(self.transformer)
+         logger.info("Trainer setup complete")
+
+     def update_weights(self, modify_weights: bool = False):
+         """Trigger a weight update by writing to shared storage.
+
+         Args:
+             modify_weights: If True, modifies weights before writing
+                 for verification purposes.
+
+         Returns:
+             str: "updated" status message
+         """
+         # Optionally modify weights for testing
+         if modify_weights:
+             with torch.no_grad():
+                 first_param = next(self.transformer.parameters())
+                 first_param.add_(0.01)
+
+         # Write weights to shared storage
+         self.sender.update_weights()
+         return "updated"
+
+     def get_first_param_sum(self):
+         """Get sum of first parameter for verification."""
+         return next(self.transformer.parameters()).sum().item()
+
+     @classmethod
+     def as_remote(cls, *args, **kwargs):
+         import ray
+
+         return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)
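
Taken together, the NCCL pair above implements a small handshake whose call order matters: the trainer must exist as a named actor before the worker fetches metadata, and NCCL group formation is collective, so both `init()` calls have to be in flight before either is awaited (hence `max_concurrency=4` on the actors). A sketch of that choreography; `scheme_cfg` is a placeholder, not the actual `VLLMWeightSyncScheme` signature:

>>> import ray
>>> from torchrl.testing.ray_helpers import WorkerTransformerNCCL, WorkerVLLMNCCL
>>> ray.init()
>>> scheme_cfg = {}  # hypothetical; see VLLMWeightSyncScheme for the real kwargs
>>> trainer = (
...     WorkerTransformerNCCL.as_remote()
...     .options(name="Trainer")  # must match the worker's trainer_actor_name
...     .remote(scheme_cfg)
... )
>>> worker = WorkerVLLMNCCL.as_remote().remote(scheme_cfg)
>>> ray.get(worker.setup.remote())           # build the vLLM engine and receiver
'setup_complete'
>>> ray.get(worker.init_metadata.remote())   # pulls metadata from the "Trainer" actor
>>> engine = ray.get(worker.get_engine.remote())
>>> # collective step: launch both init() calls, then wait on both together
>>> ray.get([trainer.init.remote(engine), worker.init.remote()])
['initialized', 'initialized']
>>> ray.get(trainer.update_weights.remote(modify_weights=True))
'updated'

The double-buffer pair avoids this rendezvous entirely: the sender writes to shared storage with `update_weights()` and the receiver pulls at its own pace with `poll_and_apply_weights()`, so no collective initialization is needed.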