torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/sac/sac-async.py

@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Async SAC Example.
+
+WARNING: This isn't a SOTA implementation but a rudimentary implementation of SAC where inference
+and training are entirely decoupled. It can achieve a 20x speedup if compile and cudagraph are used.
+Two GPUs are required for this script to run.
+The API is currently being perfected, and contributions are welcome (as usual!) - see the TODOs in this script.
+
+This is a simple self-contained example of a SAC training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import time
+
+import warnings
+from functools import partial
+
+import hydra
+import numpy as np
+import tensordict
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import (
+    compile_with_warmup,
+    get_available_device,
+    logger as torchrl_logger,
+    timeit,
+)
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector_async,
+    make_environment,
+    make_loss_module,
+    make_replay_buffer,
+    make_sac_agent,
+    make_sac_optimizer,
+    make_train_environment,
+)
+
+torch.set_float32_matmul_precision("high")
+tensordict.nn.functional_modules._exclude_td_from_pytree().set()
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config-async")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.network.device)
+        if cfg.network.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="async_sac_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    _, eval_env = make_environment(cfg, logger=logger)
+
+    # TODO: This should be simplified. We need to create the policy on cuda:1 directly because of the bounds
+    # of the TanhDistribution which cannot be sent to cuda:1 within the distribution construction (ie, the
+    # distribution kwargs need to have access to the low / high values on the right device for compile and
+    # cudagraph to work).
+    # Create agent
+    dummy_train_env = make_train_environment(cfg)
+    model, _ = make_sac_agent(cfg, dummy_train_env, eval_env, device)
+    _, exploration_policy = make_sac_agent(cfg, dummy_train_env, eval_env, "cuda:1")
+    dummy_train_env.close(raise_if_closed=False)
+    del dummy_train_env
+    exploration_policy.load_state_dict(model[0].state_dict())
+
+    # Create SAC loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+    compile_mode_collector = compile_mode  # "reduce-overhead"
+
+    # TODO: enabling prefetch for mp RBs would speed up sampling which is currently responsible for
+    # half of the compute time on the trainer side.
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device=device,
+        shared=True,
+        prefetch=0,
+    )
+
+    # TODO: Simplify this - ideally we'd like to share the uninitialized lazy tensor storage and fetch it once
+    # it's initialized
+    replay_buffer.extend(make_train_environment(cfg).rollout(1).view(-1))
+    replay_buffer.empty()
+
+    # Create off-policy collector and start it
+    collector = make_collector_async(
+        cfg,
+        partial(make_train_environment, cfg),
+        exploration_policy,
+        compile_mode=compile_mode_collector,
+        replay_buffer=replay_buffer,
+    )
+
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+
+        optimizer.zero_grad(set_to_none=True)
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=2)
+
+    cfg.compile.cudagraphs
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)
+
+    # Main loop
+    init_random_frames = cfg.collector.init_random_frames
+
+    prb = cfg.replay_buffer.prb
+    update_freq = cfg.collector.update_freq
+
+    eval_rollout_steps = cfg.env.max_episode_steps
+    log_freq = cfg.logger.log_freq
+
+    # TODO: customize this
+    num_updates = 1000
+    total_iter = 1000
+    pbar = tqdm.tqdm(total=total_iter * num_updates)
+    params = TensorDict.from_module(model[0]).data
+
+    # Wait till we have enough data to start training
+    while replay_buffer.write_count <= init_random_frames:
+        time.sleep(0.01)
+
+    losses = []
+    for i in range(total_iter * num_updates):
+        timeit.printevery(
+            num_prints=total_iter * num_updates // log_freq,
+            total_count=total_iter * num_updates,
+            erase=True,
+        )
+
+        if (i % update_freq) == 0:
+            # Update weights of the inference policy
+            torchrl_logger.info("Updating weights")
+            collector.update_policy_weights_(params)
+
+        pbar.update(1)
+
+        # Optimization steps
+        with timeit("train"):
+            with timeit("train - rb - sample"):
+                # Sample from replay buffer
+                sampled_tensordict = replay_buffer.sample()
+
+            with timeit("train - update"):
+                torch.compiler.cudagraph_mark_step_begin()
+                loss_td = update(sampled_tensordict).clone()
+            losses.append(loss_td.select("loss_actor", "loss_qvalue", "loss_alpha"))
+
+            # Update priority
+            if prb:
+                replay_buffer.update_priority(sampled_tensordict)
+
+        # Logging
+        if (i % log_freq) == (log_freq - 1):
+            torchrl_logger.info("Logging")
+            collected_frames = replay_buffer.write_count
+            metrics_to_log = {}
+            if collected_frames >= init_random_frames:
+                losses_m = torch.stack(losses).mean()
+                losses = []
+                metrics_to_log["train/q_loss"] = losses_m.get("loss_qvalue")
+                metrics_to_log["train/actor_loss"] = losses_m.get("loss_actor")
+                metrics_to_log["train/alpha_loss"] = losses_m.get("loss_alpha")
+                metrics_to_log["train/alpha"] = loss_td["alpha"]
+                metrics_to_log["train/entropy"] = loss_td["entropy"]
+            metrics_to_log["train/collected_frames"] = int(collected_frames)
+
+            # Evaluation
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    model[0],
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+            torchrl_logger.info(f"Logs: {metrics_to_log}")
+            if logger is not None:
+                metrics_to_log.update(timeit.todict(prefix="time"))
+                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+                log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+
+
+if __name__ == "__main__":
+    main()
sota-implementations/sac/sac.py

@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""SAC Example.
+
+This is a simple self-contained example of a SAC training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import compile_with_warmup, get_available_device, timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_environment,
+    make_loss_module,
+    make_replay_buffer,
+    make_sac_agent,
+    make_sac_optimizer,
+)
+
+torch.set_float32_matmul_precision("high")
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.network.device)
+        if cfg.network.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="sac_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg, logger=logger)
+
+    # Create agent
+    model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
+
+    # Create SAC loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg, train_env, exploration_policy, compile_mode=compile_mode
+    )
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device=device,
+    )
+
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    eval_iter = cfg.logger.eval_iter
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_rollout_steps = cfg.env.max_episode_steps
+
+    collector_iter = iter(collector)
+    total_iter = len(collector)
+
+    for i in range(total_iter):
+        timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+        with timeit("collect"):
+            tensordict = next(collector_iter)
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
+        collected_frames += current_frames
+
+        # Optimization steps
+        with timeit("train"):
+            if collected_frames >= init_random_frames:
+                losses = TensorDict(batch_size=[num_updates])
+                for i in range(num_updates):
+                    with timeit("rb - sample"):
+                        # Sample from replay buffer
+                        sampled_tensordict = replay_buffer.sample()
+
+                    with timeit("update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        loss_td = update(sampled_tensordict).clone()
+                    losses[i] = loss_td.select(
+                        "loss_actor", "loss_qvalue", "loss_alpha"
+                    )
+
+                    # Update priority
+                    if prb:
+                        replay_buffer.update_priority(sampled_tensordict)
+
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards
+            metrics_to_log["train/episode_length"] = episode_length.sum() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            losses = losses.mean()
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
+            metrics_to_log["train/alpha"] = loss_td["alpha"]
+            metrics_to_log["train/entropy"] = loss_td["entropy"]
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    model[0],
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+    if not train_env.is_closed:
+        train_env.close()
+
+
+if __name__ == "__main__":
+    main()