torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/discrete_sac/discrete_sac.py
@@ -0,0 +1,243 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Discrete SAC Example.
+
+ This is a simple self-contained example of a discrete SAC training script.
+
+ It supports gym state environments like CartPole.
+
+ The helper functions are coded in the utils.py associated with this script.
+ """
+
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import torch.cuda
+ import tqdm
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_collector,
+     make_environment,
+     make_loss_module,
+     make_optimizer,
+     make_replay_buffer,
+     make_sac_agent,
+ )
+
+
+ @hydra.main(version_base="1.1", config_path="", config_name="config")
+ def main(cfg: DictConfig):  # noqa: F821
+     device = cfg.network.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             device = "cuda:0"
+         else:
+             device = "cpu"
+     device = torch.device(device)
+
+     # Create logger
+     exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="DiscreteSAC_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+
+     # Create environments
+     train_env, eval_env = make_environment(cfg, logger=logger)
+
+     # Create agent
+     model = make_sac_agent(cfg, train_env, eval_env, device)
+
+     # Create discrete SAC loss
+     loss_module, target_net_updater = make_loss_module(cfg, model)
+
+     # Create replay buffer
+     replay_buffer = make_replay_buffer(
+         batch_size=cfg.optim.batch_size,
+         prb=cfg.replay_buffer.prb,
+         buffer_size=cfg.replay_buffer.size,
+         scratch_dir=cfg.replay_buffer.scratch_dir,
+         device="cpu",
+     )
+
+     # Create optimizers
+     optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
+         cfg, loss_module
+     )
+     optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+     del optimizer_actor, optimizer_critic, optimizer_alpha
+
+     def update(sampled_tensordict):
+         optimizer.zero_grad(set_to_none=True)
+
+         # Compute loss
+         loss_out = loss_module(sampled_tensordict)
+
+         actor_loss, q_loss, alpha_loss = (
+             loss_out["loss_actor"],
+             loss_out["loss_qvalue"],
+             loss_out["loss_alpha"],
+         )
+
+         # Update actor, critic and alpha
+         (q_loss + actor_loss + alpha_loss).backward()
+         optimizer.step()
+
+         # Update target params
+         target_net_updater.step()
+
+         return loss_out.detach()
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Create off-policy collector
+     collector = make_collector(
+         cfg,
+         train_env,
+         model[0],
+         compile=compile_mode is not None,
+         compile_mode=compile_mode,
+         cudagraphs=cfg.compile.cudagraphs,
+     )
+
+     # Main loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     init_random_frames = cfg.collector.init_random_frames
+     num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+     prb = cfg.replay_buffer.prb
+     eval_rollout_steps = cfg.env.max_episode_steps
+     eval_iter = cfg.logger.eval_iter
+     frames_per_batch = cfg.collector.frames_per_batch
+
+     c_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+         with timeit("collecting"):
+             collected_data = next(c_iter)
+
+         # Update weights of the inference policy
+         collector.update_policy_weights_()
+         current_frames = collected_data.numel()
+
+         pbar.update(current_frames)
+
+         collected_data = collected_data.reshape(-1)
+         with timeit("rb - extend"):
+             # Add to replay buffer
+             replay_buffer.extend(collected_data)
+         collected_frames += current_frames
+
+         # Optimization steps
+         if collected_frames >= init_random_frames:
+             tds = []
+             for _ in range(num_updates):
+                 with timeit("rb - sample"):
+                     # Sample from replay buffer
+                     sampled_tensordict = replay_buffer.sample()
+
+                 with timeit("update"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     sampled_tensordict = sampled_tensordict.to(device)
+                     loss_out = update(sampled_tensordict).clone()
+
+                 tds.append(loss_out)
+
+                 # Update priority
+                 if prb:
+                     replay_buffer.update_priority(sampled_tensordict)
+             tds = torch.stack(tds).mean()
+
+         # Logging
+         episode_end = (
+             collected_data["next", "done"]
+             if collected_data["next", "done"].any()
+             else collected_data["next", "truncated"]
+         )
+         episode_rewards = collected_data["next", "episode_reward"][episode_end]
+
+         metrics_to_log = {}
+         if len(episode_rewards) > 0:
+             episode_length = collected_data["next", "step_count"][episode_end]
+             metrics_to_log["train/reward"] = episode_rewards.mean().item()
+             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                 episode_length
+             )
+
+         if collected_frames >= init_random_frames:
+             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+             metrics_to_log["train/a_loss"] = tds["loss_actor"]
+             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
+
+         # Evaluation
+         prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
+         cur_test_frame = (i * frames_per_batch) // eval_iter
+         final = current_frames >= collector.total_frames
+         if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
+                 eval_rollout = eval_env.rollout(
+                     eval_rollout_steps,
+                     model[0],
+                     auto_cast_to_device=True,
+                     break_when_any_done=True,
+                 )
+                 eval_env.apply(dump_video)
+                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                 metrics_to_log["eval/reward"] = eval_reward
+         if logger is not None:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, collected_frames)
+
+     collector.shutdown()
+     if not eval_env.is_closed:
+         eval_env.close()
+     if not train_env.is_closed:
+         train_env.close()
+
+
+ if __name__ == "__main__":
+     main()
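In the `update` step above, the three Adam optimizers returned by `make_optimizer` are fused with `group_optimizers`, so a single `zero_grad()`/`step()` pair drives the actor, critic and alpha parameters at once. A minimal standalone sketch of that pattern (the modules, shapes and learning rates below are illustrative, and it only assumes that `group_optimizers` merges the given optimizers into one optimizer-like object, which is how the script uses it):

```python
import torch
from torch import nn, optim
from torchrl.objectives import group_optimizers

actor = nn.Linear(4, 2)   # placeholder "actor"
critic = nn.Linear(4, 1)  # placeholder "critic"
opt = group_optimizers(
    optim.Adam(actor.parameters(), lr=3e-4),
    optim.Adam(critic.parameters(), lr=3e-4),
)

loss = actor(torch.randn(8, 4)).sum() + critic(torch.randn(8, 4)).sum()
opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()  # one call updates both parameter sets
```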
sota-implementations/discrete_sac/utils.py
@@ -0,0 +1,324 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+ import tempfile
+ from contextlib import nullcontext
+
+ import torch
+ from tensordict.nn import InteractionType, TensorDictModule
+
+ from torch import nn, optim
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import (
+     Composite,
+     TensorDictPrioritizedReplayBuffer,
+     TensorDictReplayBuffer,
+ )
+ from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+ from torchrl.envs import (
+     CatTensors,
+     Compose,
+     DMControlEnv,
+     DoubleToFloat,
+     EnvCreator,
+     InitTracker,
+     ParallelEnv,
+     RewardSum,
+     StepCounter,
+     TransformedEnv,
+ )
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import MLP, SafeModule
+ from torchrl.modules.distributions import OneHotCategorical
+
+ from torchrl.modules.tensordict_module.actors import ProbabilisticActor
+ from torchrl.objectives import SoftUpdate
+ from torchrl.objectives.sac import DiscreteSACLoss
+ from torchrl.record import VideoRecorder
+
+
+ # ====================================================================
+ # Environment utils
+ # -----------------
+
+
+ def env_maker(cfg, device="cpu", from_pixels=False):
+     lib = cfg.env.library
+     if lib in ("gym", "gymnasium"):
+         with set_gym_backend(lib):
+             return GymEnv(
+                 cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+             )
+     elif lib == "dm_control":
+         env = DMControlEnv(
+             cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+         )
+         return TransformedEnv(
+             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+         )
+     else:
+         raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+ def apply_env_transforms(env, max_episode_steps):
+     transformed_env = TransformedEnv(
+         env,
+         Compose(
+             StepCounter(max_steps=max_episode_steps),
+             InitTracker(),
+             DoubleToFloat(),
+             RewardSum(),
+         ),
+     )
+     return transformed_env
+
+
+ def make_environment(cfg, logger=None):
+     """Make environments for training and evaluation."""
+     maker = functools.partial(env_maker, cfg)
+     parallel_env = ParallelEnv(
+         cfg.collector.env_per_collector,
+         EnvCreator(maker),
+         serial_for_single=True,
+     )
+     parallel_env.set_seed(cfg.env.seed)
+
+     train_env = apply_env_transforms(
+         parallel_env, max_episode_steps=cfg.env.max_episode_steps
+     )
+
+     maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+     eval_env = TransformedEnv(
+         ParallelEnv(
+             cfg.collector.env_per_collector,
+             EnvCreator(maker),
+             serial_for_single=True,
+         ),
+         train_env.transform.clone(),
+     )
+     if cfg.logger.video:
+         eval_env = eval_env.insert_transform(
+             0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+         )
+     return train_env, eval_env
+
+
+ # ====================================================================
+ # Collector and replay buffer
+ # ---------------------------
+
+
+ def make_collector(
+     cfg,
+     train_env,
+     actor_model_explore,
+     compile=False,
+     compile_mode=None,
+     cudagraphs=False,
+ ):
+     """Make collector."""
+     device = cfg.collector.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             device = "cuda:0"
+         else:
+             device = "cpu"
+     device = torch.device(device)
+     collector = SyncDataCollector(
+         train_env,
+         actor_model_explore,
+         init_random_frames=cfg.collector.init_random_frames,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         reset_at_each_iter=cfg.collector.reset_at_each_iter,
+         device=device,
+         storing_device="cpu",
+         compile_policy=False if not compile else {"mode": compile_mode},
+         cudagraph_policy={"warmup": 10} if cudagraphs else False,
+     )
+     collector.set_seed(cfg.env.seed)
+     return collector
+
+
+ def make_replay_buffer(
+     batch_size,
+     prb=False,
+     buffer_size=1000000,
+     scratch_dir=None,
+     device="cpu",
+     prefetch=3,
+ ):
+     with (
+         tempfile.TemporaryDirectory()
+         if scratch_dir is None
+         else nullcontext(scratch_dir)
+     ) as scratch_dir:
+         if prb:
+             replay_buffer = TensorDictPrioritizedReplayBuffer(
+                 alpha=0.7,
+                 beta=0.5,
+                 pin_memory=False,
+                 prefetch=prefetch,
+                 storage=LazyMemmapStorage(
+                     buffer_size,
+                     scratch_dir=scratch_dir,
+                     device=device,
+                 ),
+                 batch_size=batch_size,
+             )
+         else:
+             replay_buffer = TensorDictReplayBuffer(
+                 pin_memory=False,
+                 prefetch=prefetch,
+                 storage=LazyMemmapStorage(
+                     buffer_size,
+                     scratch_dir=scratch_dir,
+                     device=device,
+                 ),
+                 batch_size=batch_size,
+             )
+         return replay_buffer
+
+
+ # ====================================================================
+ # Model
+ # -----
+
+
+ def make_sac_agent(cfg, train_env, eval_env, device):
+     """Make discrete SAC agent."""
+     # Define Actor Network
+     in_keys = ["observation"]
+     action_spec = train_env.action_spec
+     if train_env.batch_size:
+         action_spec = action_spec[(0,) * len(train_env.batch_size)]
+     # Define Actor Network
+     in_keys = ["observation"]
+
+     actor_net_kwargs = {
+         "num_cells": cfg.network.hidden_sizes,
+         "out_features": action_spec.shape[-1],
+         "activation_class": get_activation(cfg),
+     }
+
+     actor_net = MLP(**actor_net_kwargs)
+
+     actor_module = SafeModule(
+         module=actor_net,
+         in_keys=in_keys,
+         out_keys=["logits"],
+     )
+     actor = ProbabilisticActor(
+         spec=Composite(action=eval_env.action_spec),
+         module=actor_module,
+         in_keys=["logits"],
+         out_keys=["action"],
+         distribution_class=OneHotCategorical,
+         distribution_kwargs={},
+         default_interaction_type=InteractionType.RANDOM,
+         return_log_prob=False,
+     )
+
+     # Define Critic Network
+     qvalue_net_kwargs = {
+         "num_cells": cfg.network.hidden_sizes,
+         "out_features": action_spec.shape[-1],
+         "activation_class": get_activation(cfg),
+     }
+     qvalue_net = MLP(
+         **qvalue_net_kwargs,
+     )
+
+     qvalue = TensorDictModule(
+         in_keys=in_keys,
+         out_keys=["action_value"],
+         module=qvalue_net,
+     )
+
+     model = torch.nn.ModuleList([actor, qvalue]).to(device)
+     # init nets
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         td = eval_env.reset()
+         td = td.to(device)
+         for net in model:
+             net(td)
+     del td
+     eval_env.close()
+
+     return model
+
+
+ # ====================================================================
+ # Discrete SAC Loss
+ # ---------
+
+
+ def make_loss_module(cfg, model):
+     """Make loss module and target network updater."""
+     # Create discrete SAC loss
+     loss_module = DiscreteSACLoss(
+         actor_network=model[0],
+         qvalue_network=model[1],
+         num_actions=model[0].spec["action"].space.n,
+         num_qvalue_nets=2,
+         loss_function=cfg.optim.loss_function,
+         target_entropy_weight=cfg.optim.target_entropy_weight,
+         delay_qvalue=True,
+     )
+     loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+     # Define Target Network Updater
+     target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
+     return loss_module, target_net_updater
+
+
+ def make_optimizer(cfg, loss_module):
+     critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+     actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+     optimizer_actor = optim.Adam(
+         actor_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+     )
+     optimizer_critic = optim.Adam(
+         critic_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+     )
+     optimizer_alpha = optim.Adam(
+         [loss_module.log_alpha],
+         lr=3.0e-4,
+     )
+     return optimizer_actor, optimizer_critic, optimizer_alpha
+
+
+ # ====================================================================
+ # General utils
+ # ---------
+
+
+ def log_metrics(logger, metrics, step):
+     for metric_name, metric_value in metrics.items():
+         logger.log_scalar(metric_name, metric_value, step)
+
+
+ def get_activation(cfg):
+     if cfg.network.activation == "relu":
+         return nn.ReLU
+     elif cfg.network.activation == "tanh":
+         return nn.Tanh
+     elif cfg.network.activation == "leaky_relu":
+         return nn.LeakyReLU
+     else:
+         raise NotImplementedError
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
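`make_replay_buffer` above is a thin wrapper around torchrl's `TensorDictReplayBuffer` backed by a memory-mapped `LazyMemmapStorage`. A minimal standalone sketch of the non-prioritized path (the keys, shapes and sizes below are illustrative; as in the helper, the storage falls back to a temporary scratch directory when none is given):

```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(10_000),  # memory-mapped buffer of up to 10k transitions
    batch_size=256,
)
# Fake collected transitions with illustrative keys and shapes
data = TensorDict(
    {"observation": torch.randn(1024, 4), "action": torch.randn(1024, 2)},
    batch_size=[1024],
)
rb.extend(data)      # what the training loop does with each collected batch
batch = rb.sample()  # TensorDict with batch_size=[256], ready for the loss module
```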
sota-implementations/dqn/README.md
@@ -0,0 +1,30 @@
+ ## Reproducing Deep Q-Learning (DQN) Algorithm Results
+
+ This repository contains scripts that enable training agents using the Deep Q-Learning (DQN) Algorithm on CartPole and Atari environments. For Atari, we follow the original paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) by Mnih et al. (2013).
+
+
+ ## Examples Structure
+
+ Please note that each example is independent of the others for the sake of simplicity. Each example contains the following files:
+
+ 1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. dqn_atari.py).
+
+ 2. **Utils File:** A utility file provides various helper functions, generally for creating the environment and the models (e.g. utils_atari.py).
+
+ 3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml).
+
+
+ ## Running the Examples
+
+ You can execute the DQN algorithm on the CartPole environment by running the following command:
+
+ ```bash
+ python dqn_cartpole.py
+
+ ```
+
+ You can execute the DQN algorithm on Atari environments by running the following command:
+
+ ```bash
+ python dqn_atari.py
+ ```