torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/a2c/a2c_atari.py
@@ -0,0 +1,291 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import torch
+
+torch.set_float32_matmul_precision("high")
+
+
+@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
+def main(cfg: DictConfig):  # noqa: F821
+
+    from copy import deepcopy
+
+    import torch.optim
+    import tqdm
+    from tensordict import from_module
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import get_available_device, timeit
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+    from torchrl.envs import ExplorationType, set_exploration_type
+    from torchrl.objectives import A2CLoss
+    from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
+    from torchrl.record.loggers import generate_exp_name, get_logger
+    from utils_atari import eval_model, make_parallel_env, make_ppo_models
+
+    device = (
+        torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
+    )
+
+    # Correct for frame_skip
+    frame_skip = 4
+    total_frames = cfg.collector.total_frames // frame_skip
+    frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+    mini_batch_size = cfg.loss.mini_batch_size // frame_skip
+    test_interval = cfg.logger.test_interval // frame_skip
+
+    # Create models (check utils_atari.py)
+    actor, critic, critic_head = make_ppo_models(
+        cfg.env.env_name, device=device, gym_backend=cfg.env.backend
+    )
+    with from_module(actor).data.to("meta").to_module(actor):
+        actor_eval = deepcopy(actor)
+    actor_eval.eval()
+    from_module(actor).data.to_module(actor_eval)
+
+    # Create data buffer
+    sampler = SamplerWithoutReplacement()
+    data_buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(frames_per_batch, device=device),
+        sampler=sampler,
+        batch_size=mini_batch_size,
+    )
+
+    # Create loss and adv modules
+    adv_module = GAE(
+        gamma=cfg.loss.gamma,
+        lmbda=cfg.loss.gae_lambda,
+        value_network=critic,
+        average_gae=True,
+        vectorized=not cfg.compile.compile,
+        device=device,
+    )
+    loss_module = A2CLoss(
+        actor_network=actor,
+        critic_network=critic,
+        loss_critic_type=cfg.loss.loss_critic_type,
+        entropy_coeff=cfg.loss.entropy_coeff,
+        critic_coeff=cfg.loss.critic_coeff,
+    )
+
+    # use end-of-life as done key
+    adv_module.set_keys(done="end-of-life", terminated="end-of-life")
+    loss_module.set_keys(done="end-of-life", terminated="end-of-life")
+
+    # Create optimizer
+    optim = torch.optim.Adam(
+        loss_module.parameters(),
+        lr=torch.tensor(cfg.optim.lr, device=device),
+        weight_decay=cfg.optim.weight_decay,
+        eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
+    )
+
+    # Create logger
+    logger = None
+    if cfg.logger.backend:
+        exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
+        logger = get_logger(
+            cfg.logger.backend,
+            logger_name="a2c",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Create test environment
+    test_env = make_parallel_env(
+        cfg.env.env_name,
+        num_envs=1,
+        device=device,
+        gym_backend=cfg.env.backend,
+        is_test=True,
+    )
+    test_env.set_seed(0)
+    if cfg.logger.video:
+        test_env = test_env.insert_transform(
+            0,
+            VideoRecorder(
+                logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
+            ),
+        )
+    test_env.eval()
+
+    # update function
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(
+            cfg.env.env_name,
+            num_envs=cfg.env.num_envs,
+            device=device,
+            gym_backend=cfg.env.backend,
+        ),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+        compile_policy={"mode": compile_mode} if cfg.compile.compile else False,
+        cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+    )
+
+    # Main loop
+    collected_frames = 0
+    num_network_updates = 0
+    pbar = tqdm.tqdm(total=total_frames)
+    num_mini_batches = frames_per_batch // mini_batch_size
+    total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
+    lr = cfg.optim.lr
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
+        with timeit("collecting"):
+            data = next(c_iter)
+
+        metrics_to_log = {}
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch * frame_skip
+        pbar.update(data.numel())
+
+        # Get training rewards and lengths
+        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+        if len(episode_rewards) > 0:
+            episode_length = data["next", "step_count"][data["next", "terminated"]]
+            metrics_to_log.update(
+                {
+                    "train/reward": episode_rewards.mean().item(),
+                    "train/episode_length": episode_length.sum().item()
+                    / len(episode_length),
+                }
+            )
+
+        losses = []
+
+        # Compute GAE
+        with torch.no_grad(), timeit("advantage"):
+            torch.compiler.cudagraph_mark_step_begin()
+            data = adv_module(data)
+        data_reshape = data.reshape(-1)
+
+        # Update the data buffer
+        with timeit("rb - emptying"):
+            data_buffer.empty()
+        with timeit("rb - extending"):
+            data_buffer.extend(data_reshape)
+
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)
+
+                num_network_updates += 1
+
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch).clone()
+                losses.append(loss)
+
+        # Get training losses
+        losses = torch.stack(losses).float().mean()
+
+        for key, value in losses.items():
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
+            {
+                "train/lr": lr * alpha,
+            }
+        )
+
+        # Get test rewards
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
+            if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+                i * frames_in_batch * frame_skip
+            ) // test_interval:
+                test_rewards = eval_model(
+                    actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
+                )
+                metrics_to_log.update(
+                    {
+                        "test/reward": test_rewards.mean(),
+                    }
+                )
+
+        if logger:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
+                logger.log_scalar(key, value, collected_frames)
+
+    collector.shutdown()
+    if not test_env.is_closed:
+        test_env.close()
+
+
+if __name__ == "__main__":
+    main()
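
Note: both A2C scripts build a weight-tied evaluation copy of the actor with the same tensordict idiom: the real parameters are temporarily swapped for "meta"-device tensors before deepcopy, so only the module structure is cloned, and the live parameters are then re-bound into the copy. A minimal sketch of that pattern in isolation (the `net` module below is an illustrative stand-in, not part of the package):

    from copy import deepcopy

    import torch
    from tensordict import from_module

    net = torch.nn.Linear(3, 2)

    # Swap the real parameters for meta tensors so deepcopy clones the
    # module structure without allocating or copying any weights.
    with from_module(net).data.to("meta").to_module(net):
        net_eval = deepcopy(net)
    net_eval.eval()

    # Re-bind the live parameters into the copy: both modules now share
    # storage, so training updates to net are visible through net_eval.
    from_module(net).data.to_module(net_eval)
    assert net_eval.weight.data_ptr() == net.weight.data_ptr()

This keeps the collector's policy and the evaluation policy in sync without any explicit weight copying.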
sota-implementations/a2c/a2c_mujoco.py
@@ -0,0 +1,273 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import torch
+
+torch.set_float32_matmul_precision("high")
+
+
+@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
+def main(cfg: DictConfig):  # noqa: F821
+
+    from copy import deepcopy
+
+    import torch.optim
+    import tqdm
+
+    from tensordict import from_module
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import get_available_device, timeit
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+    from torchrl.envs import ExplorationType, set_exploration_type
+    from torchrl.objectives import A2CLoss, group_optimizers
+    from torchrl.objectives.value import GAE
+    from torchrl.record import VideoRecorder
+    from torchrl.record.loggers import generate_exp_name, get_logger
+    from utils_mujoco import eval_model, make_env, make_ppo_models
+
+    # Define paper hyperparameters
+
+    device = (
+        torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
+    )
+
+    num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
+    total_network_updates = (
+        cfg.collector.total_frames // cfg.collector.frames_per_batch
+    ) * num_mini_batches
+
+    # Create models (check utils_mujoco.py)
+    actor, critic = make_ppo_models(
+        cfg.env.env_name, device=device, compile=cfg.compile.compile
+    )
+    with from_module(actor).data.to("meta").to_module(actor):
+        actor_eval = deepcopy(actor)
+    actor_eval.eval()
+    from_module(actor).data.to_module(actor_eval)
+
+    # Create data buffer
+    sampler = SamplerWithoutReplacement()
+    data_buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device),
+        sampler=sampler,
+        batch_size=cfg.loss.mini_batch_size,
+    )
+
+    # Create loss and adv modules
+    adv_module = GAE(
+        gamma=cfg.loss.gamma,
+        lmbda=cfg.loss.gae_lambda,
+        value_network=critic,
+        average_gae=False,
+        vectorized=not cfg.compile.compile,
+        device=device,
+    )
+    loss_module = A2CLoss(
+        actor_network=actor,
+        critic_network=critic,
+        loss_critic_type=cfg.loss.loss_critic_type,
+        entropy_coeff=cfg.loss.entropy_coeff,
+        critic_coeff=cfg.loss.critic_coeff,
+    )
+
+    # Create optimizers
+    actor_optim = torch.optim.Adam(
+        actor.parameters(),
+        lr=torch.tensor(cfg.optim.lr, device=device),
+        capturable=device.type == "cuda",
+    )
+    critic_optim = torch.optim.Adam(
+        critic.parameters(),
+        lr=torch.tensor(cfg.optim.lr, device=device),
+        capturable=device.type == "cuda",
+    )
+    optim = group_optimizers(actor_optim, critic_optim)
+    del actor_optim, critic_optim
+
+    # Create logger
+    logger = None
+    if cfg.logger.backend:
+        exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
+        logger = get_logger(
+            cfg.logger.backend,
+            logger_name="a2c",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Create test environment
+    test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)
+    test_env.set_seed(0)
+    if cfg.logger.video:
+        test_env = test_env.insert_transform(
+            0,
+            VideoRecorder(
+                logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
+            ),
+        )
+
+    def update(batch):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+        critic_loss = loss["loss_critic"]
+        actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
+
+        # Backward pass
+        (actor_loss + critic_loss).backward()
+
+        # Update the networks
+        optim.step()
+
+        optim.zero_grad(set_to_none=True)
+        return loss.select("loss_critic", "loss_objective").detach()  # , "loss_entropy"
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+        update = torch.compile(update, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20)
+        adv_module = CudaGraphModule(adv_module, warmup=20)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, device),
+        policy=actor,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+        trust_policy=True,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+        cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+    )
+
+    test_env.eval()
+    lr = cfg.optim.lr
+
+    # Main loop
+    collected_frames = 0
+    num_network_updates = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+
+        with timeit("collecting"):
+            data = next(c_iter)
+
+        metrics_to_log = {}
+        frames_in_batch = data.numel()
+        collected_frames += frames_in_batch
+        pbar.update(data.numel())
+
+        # Get training rewards and lengths
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+        if len(episode_rewards) > 0:
+            episode_length = data["next", "step_count"][data["next", "done"]]
+            metrics_to_log.update(
+                {
+                    "train/reward": episode_rewards.mean().item(),
+                    "train/episode_length": episode_length.sum().item()
+                    / len(episode_length),
+                }
+            )
+
+        losses = []
+
+        # Compute GAE
+        with torch.no_grad(), timeit("advantage"):
+            torch.compiler.cudagraph_mark_step_begin()
+            data = adv_module(data)
+        data_reshape = data.reshape(-1)
+
+        # Update the data buffer
+        with timeit("emptying"):
+            data_buffer.empty()
+        with timeit("extending"):
+            data_buffer.extend(data_reshape)
+
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)
+                num_network_updates += 1
+                with timeit("optim - update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch).clone()
+                losses.append(loss)
+
+        # Get training losses
+        losses = torch.stack(losses).float().mean()
+        for key, value in losses.items():
+            metrics_to_log.update({f"train/{key}": value.item()})
+        metrics_to_log.update(
+            {
+                "train/lr": alpha * cfg.optim.lr,
+            }
+        )
+
+        # Get test rewards
+        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+            prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
+            cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
+            final = collected_frames >= collector.total_frames
+            if prev_test_frame < cur_test_frame or final:
+                actor.eval()
+                test_rewards = eval_model(
+                    actor, test_env, num_episodes=cfg.logger.num_test_episodes
+                )
+                metrics_to_log.update(
+                    {
+                        "test/reward": test_rewards.mean(),
+                    }
+                )
+                actor.train()
+
+        if logger:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            for key, value in metrics_to_log.items():
+                logger.log_scalar(key, value, collected_frames)
+
+    collector.shutdown()
+    if not test_env.is_closed:
+        test_env.close()
+
+
+if __name__ == "__main__":
+    main()
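
Note: unlike the Atari script, the MuJoCo variant keeps separate Adam optimizers for the actor and critic and fuses them with group_optimizers, so the update function can issue a single step()/zero_grad() for both. A rough illustration of the contract such grouping has to satisfy for the code above to work; this is a behavioral stand-in written for this note, not torchrl's actual implementation:

    import torch

    class GroupedOptimizer:
        """Illustrative stand-in: one step()/zero_grad() driving several optimizers."""

        def __init__(self, *optimizers):
            self.optimizers = optimizers

        @property
        def param_groups(self):
            # Expose a flat view so code like `for group in optim.param_groups`
            # (used above for lr annealing) reaches every sub-optimizer.
            return [g for opt in self.optimizers for g in opt.param_groups]

        def step(self):
            for opt in self.optimizers:
                opt.step()

        def zero_grad(self, set_to_none=True):
            for opt in self.optimizers:
                opt.zero_grad(set_to_none=set_to_none)

    actor = torch.nn.Linear(4, 2)
    critic = torch.nn.Linear(4, 1)
    optim = GroupedOptimizer(
        torch.optim.Adam(actor.parameters(), lr=3e-4),
        torch.optim.Adam(critic.parameters(), lr=3e-4),
    )

Grouping keeps the actor and critic parameters in distinct param groups (and thus separable learning rates) while letting the compiled update function treat them as one optimizer.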