torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/multiagent/sac.py
@@ -0,0 +1,337 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import time
+
+ import hydra
+ import torch
+ from tensordict.nn import TensorDictModule
+ from tensordict.nn.distributions import NormalParamExtractor
+ from torch import nn
+ from torch.distributions import Categorical, OneHotCategorical
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.data.replay_buffers.storages import LazyTensorStorage
+ from torchrl.envs import RewardSum, TransformedEnv
+ from torchrl.envs.libs.vmas import VmasEnv
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
+ from torchrl.modules.models.multiagent import MultiAgentMLP
+ from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators
+ from utils.logging import init_logging, log_evaluation, log_training
+ from utils.utils import DoneTransform
+
+
+ def rendering_callback(env, td):
+     env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))
+
+
+ @hydra.main(version_base="1.1", config_path="", config_name="sac")
+ def train(cfg: DictConfig):  # noqa: F821
+     # Device
+     cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
+     cfg.env.device = cfg.train.device
+
+     # Seeding
+     torch.manual_seed(cfg.seed)
+
+     # Sampling
+     cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
+     cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
+     cfg.buffer.memory_size = cfg.collector.frames_per_batch
+
+     # Create env and env_test
+     env = VmasEnv(
+         scenario=cfg.env.scenario_name,
+         num_envs=cfg.env.vmas_envs,
+         continuous_actions=cfg.env.continuous_actions,
+         max_steps=cfg.env.max_steps,
+         device=cfg.env.device,
+         seed=cfg.seed,
+         categorical_actions=cfg.env.categorical_actions,
+         # Scenario kwargs
+         **cfg.env.scenario,
+     )
+     env = TransformedEnv(
+         env,
+         RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
+     )
+
+     env_test = VmasEnv(
+         scenario=cfg.env.scenario_name,
+         num_envs=cfg.eval.evaluation_episodes,
+         continuous_actions=cfg.env.continuous_actions,
+         max_steps=cfg.env.max_steps,
+         device=cfg.env.device,
+         seed=cfg.seed,
+         # Scenario kwargs
+         **cfg.env.scenario,
+     )
+
+     # Policy
+     if cfg.env.continuous_actions:
+         actor_net = nn.Sequential(
+             MultiAgentMLP(
+                 n_agent_inputs=env.full_observation_spec_unbatched[
+                     "agents", "observation"
+                 ].shape[-1],
+                 n_agent_outputs=2
+                 * env.full_action_spec_unbatched["agents", "action"].shape[-1],
+                 n_agents=env.n_agents,
+                 centralised=False,
+                 share_params=cfg.model.shared_parameters,
+                 device=cfg.train.device,
+                 depth=2,
+                 num_cells=256,
+                 activation_class=nn.Tanh,
+             ),
+             NormalParamExtractor(),
+         )
+         policy_module = TensorDictModule(
+             actor_net,
+             in_keys=[("agents", "observation")],
+             out_keys=[("agents", "loc"), ("agents", "scale")],
+         )
+
+         policy = ProbabilisticActor(
+             module=policy_module,
+             spec=env.full_action_spec_unbatched,
+             in_keys=[("agents", "loc"), ("agents", "scale")],
+             out_keys=[env.action_key],
+             distribution_class=TanhNormal,
+             distribution_kwargs={
+                 "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+                 "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
+             },
+             return_log_prob=True,
+         )
+
+         # Critic
+         module = MultiAgentMLP(
+             n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1]
+             + env.full_action_spec_unbatched["agents", "action"].shape[
+                 -1
+             ],  # Q critic takes the observation and the action as input
+             n_agent_outputs=1,
+             n_agents=env.n_agents,
+             centralised=cfg.model.centralised_critic,
+             share_params=cfg.model.shared_parameters,
+             device=cfg.train.device,
+             depth=2,
+             num_cells=256,
+             activation_class=nn.Tanh,
+         )
+         value_module = ValueOperator(
+             module=module,
+             in_keys=[("agents", "observation"), env.action_key],
+             out_keys=[("agents", "state_action_value")],
+         )
+     else:
+         actor_net = nn.Sequential(
+             MultiAgentMLP(
+                 n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
+                 n_agent_outputs=env.full_action_spec_unbatched[
+                     "agents", "action"
+                 ].space.n,
+                 n_agents=env.n_agents,
+                 centralised=False,
+                 share_params=cfg.model.shared_parameters,
+                 device=cfg.train.device,
+                 depth=2,
+                 num_cells=256,
+                 activation_class=nn.Tanh,
+             ),
+         )
+         policy_module = TensorDictModule(
+             actor_net,
+             in_keys=[("agents", "observation")],
+             out_keys=[("agents", "logits")],
+         )
+         policy = ProbabilisticActor(
+             module=policy_module,
+             spec=env.full_action_spec_unbatched["agents", "action"],
+             in_keys=[("agents", "logits")],
+             out_keys=[env.action_key],
+             distribution_class=OneHotCategorical
+             if not cfg.env.categorical_actions
+             else Categorical,
+             return_log_prob=True,
+         )
+
+         # Critic
+         module = MultiAgentMLP(
+             n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
+             n_agent_outputs=env.full_action_spec_unbatched["agents", "action"].space.n,
+             n_agents=env.n_agents,
+             centralised=cfg.model.centralised_critic,
+             share_params=cfg.model.shared_parameters,
+             device=cfg.train.device,
+             depth=2,
+             num_cells=256,
+             activation_class=nn.Tanh,
+         )
+         value_module = ValueOperator(
+             module=module,
+             in_keys=[("agents", "observation")],
+             out_keys=[("agents", "action_value")],
+         )
+
+     collector = SyncDataCollector(
+         env,
+         policy,
+         device=cfg.env.device,
+         storing_device=cfg.train.device,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
+     )
+
+     replay_buffer = TensorDictReplayBuffer(
+         storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device),
+         sampler=SamplerWithoutReplacement(),
+         batch_size=cfg.train.minibatch_size,
+     )
+
+     if cfg.env.continuous_actions:
+         loss_module = SACLoss(
+             actor_network=policy,
+             qvalue_network=value_module,
+             delay_qvalue=True,
+             action_spec=env.full_action_spec_unbatched,
+         )
+         loss_module.set_keys(
+             state_action_value=("agents", "state_action_value"),
+             action=env.action_key,
+             reward=env.reward_key,
+             done=("agents", "done"),
+             terminated=("agents", "terminated"),
+         )
+     else:
+         loss_module = DiscreteSACLoss(
+             actor_network=policy,
+             qvalue_network=value_module,
+             delay_qvalue=True,
+             num_actions=env.full_action_spec_unbatched["agents", "action"].space.n,
+             action_space=env.full_action_spec_unbatched,
+         )
+         loss_module.set_keys(
+             action_value=("agents", "action_value"),
+             action=env.action_key,
+             reward=env.reward_key,
+             done=("agents", "done"),
+             terminated=("agents", "terminated"),
+         )
+
+     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
+     target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
+
+     optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr)
+
+     # Logging
+     if cfg.logger.backend:
+         model_name = (
+             ("Het" if not cfg.model.shared_parameters else "")
+             + ("MA" if cfg.model.centralised_critic else "I")
+             + "SAC"
+         )
+         logger = init_logging(cfg, model_name)
+
+     total_time = 0
+     total_frames = 0
+     sampling_start = time.time()
+     for i, tensordict_data in enumerate(collector):
+         torchrl_logger.info(f"\nIteration {i}")
+
+         sampling_time = time.time() - sampling_start
+
+         current_frames = tensordict_data.numel()
+         total_frames += current_frames
+         data_view = tensordict_data.reshape(-1)
+         replay_buffer.extend(data_view)
+
+         training_tds = []
+         training_start = time.time()
+         for _ in range(cfg.train.num_epochs):
+             for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
+                 subdata = replay_buffer.sample()
+                 loss_vals = loss_module(subdata)
+                 training_tds.append(loss_vals.detach())
+
+                 loss_value = (
+                     loss_vals["loss_actor"]
+                     + loss_vals["loss_alpha"]
+                     + loss_vals["loss_qvalue"]
+                 )
+
+                 loss_value.backward()
+
+                 total_norm = torch.nn.utils.clip_grad_norm_(
+                     loss_module.parameters(), cfg.train.max_grad_norm
+                 )
+                 training_tds[-1].set("grad_norm", total_norm.mean())
+
+                 optim.step()
+                 optim.zero_grad()
+                 target_net_updater.step()
+
+         collector.update_policy_weights_()
+
+         training_time = time.time() - training_start
+
+         iteration_time = sampling_time + training_time
+         total_time += iteration_time
+         training_tds = torch.stack(training_tds)
+
+         # More logs
+         if cfg.logger.backend:
+             log_training(
+                 logger,
+                 training_tds,
+                 tensordict_data,
+                 sampling_time,
+                 training_time,
+                 total_time,
+                 i,
+                 current_frames,
+                 total_frames,
+                 step=i,
+             )
+
+         if (
+             cfg.eval.evaluation_episodes > 0
+             and i % cfg.eval.evaluation_interval == 0
+             and cfg.logger.backend
+         ):
+             evaluation_start = time.time()
+             with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+                 env_test.frames = []
+                 rollouts = env_test.rollout(
+                     max_steps=cfg.env.max_steps,
+                     policy=policy,
+                     callback=rendering_callback,
+                     auto_cast_to_device=True,
+                     break_when_any_done=False,
+                     # We are running vectorized evaluation; we do not want it to stop when just one env is done
+                 )
+
+                 evaluation_time = time.time() - evaluation_start
+
+                 log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)
+
+         if cfg.logger.backend == "wandb":
+             logger.experiment.log({}, commit=True)
+         sampling_start = time.time()
+     collector.shutdown()
+     if not env.is_closed:
+         env.close()
+     if not env_test.is_closed:
+         env_test.close()
+
+
+ if __name__ == "__main__":
+     train()
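Editor's note (not part of the package): in the script above, the soft-update rate `tau` from the config is converted into the `eps` argument of `SoftUpdate` via `eps = 1 - cfg.loss.tau`, because `eps` is the fraction of each target parameter retained at every update. A minimal sketch of that relationship, with an illustrative `tau` value:

```python
# Editorial sketch only; the real value comes from cfg.loss.tau.
tau = 0.005
eps = 1 - tau  # 0.995


def soft_update(target: float, online: float) -> float:
    # target <- eps * target + (1 - eps) * online, the blend applied by target_net_updater.step()
    return eps * target + (1 - eps) * online


print(soft_update(1.0, 0.0))  # 0.995
```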
sota-implementations/multiagent/utils/__init__.py
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
sota-implementations/multiagent/utils/logging.py
@@ -0,0 +1,151 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import os
+
+ import numpy as np
+ import torch
+ from tensordict import TensorDictBase
+ from torchrl.envs.libs.vmas import VmasEnv
+ from torchrl.record.loggers import generate_exp_name, get_logger, Logger
+ from torchrl.record.loggers.wandb import WandbLogger
+
+
+ def init_logging(cfg, model_name: str):
+     logger = get_logger(
+         logger_type=cfg.logger.backend,
+         logger_name=os.getcwd(),
+         experiment_name=generate_exp_name(cfg.env.scenario_name, model_name),
+         wandb_kwargs={
+             "group": cfg.logger.group_name or model_name,
+             "project": cfg.logger.project_name
+             or f"torchrl_example_{cfg.env.scenario_name}",
+         },
+     )
+     logger.log_hparams(cfg)
+     return logger
+
+
+ def log_training(
+     logger: Logger,
+     training_td: TensorDictBase,
+     sampling_td: TensorDictBase,
+     sampling_time: float,
+     training_time: float,
+     total_time: float,
+     iteration: int,
+     current_frames: int,
+     total_frames: int,
+     step: int,
+ ):
+     if ("next", "agents", "reward") not in sampling_td.keys(True, True):
+         sampling_td.set(
+             ("next", "agents", "reward"),
+             sampling_td.get(("next", "reward"))
+             .expand(sampling_td.get("agents").shape)
+             .unsqueeze(-1),
+         )
+     if ("next", "agents", "episode_reward") not in sampling_td.keys(True, True):
+         sampling_td.set(
+             ("next", "agents", "episode_reward"),
+             sampling_td.get(("next", "episode_reward"))
+             .expand(sampling_td.get("agents").shape)
+             .unsqueeze(-1),
+         )
+
+     metrics_to_log = {
+         f"train/learner/{key}": value.mean().item()
+         for key, value in training_td.items()
+     }
+
+     if "info" in sampling_td.get("agents").keys():
+         metrics_to_log.update(
+             {
+                 f"train/info/{key}": value.mean().item()
+                 for key, value in sampling_td.get(("agents", "info")).items()
+             }
+         )
+
+     reward = sampling_td.get(("next", "agents", "reward")).mean(-2)  # Mean over agents
+     done = sampling_td.get(("next", "done"))
+     if done.ndim > reward.ndim:
+         done = done[..., 0, :]  # Remove expanded agent dim
+     episode_reward = sampling_td.get(("next", "agents", "episode_reward")).mean(-2)[
+         done
+     ]
+     metrics_to_log.update(
+         {
+             "train/reward/reward_min": reward.min().item(),
+             "train/reward/reward_mean": reward.mean().item(),
+             "train/reward/reward_max": reward.max().item(),
+             "train/reward/episode_reward_min": episode_reward.min().item(),
+             "train/reward/episode_reward_mean": episode_reward.mean().item(),
+             "train/reward/episode_reward_max": episode_reward.max().item(),
+             "train/sampling_time": sampling_time,
+             "train/training_time": training_time,
+             "train/iteration_time": training_time + sampling_time,
+             "train/total_time": total_time,
+             "train/training_iteration": iteration,
+             "train/current_frames": current_frames,
+             "train/total_frames": total_frames,
+         }
+     )
+     if isinstance(logger, WandbLogger):
+         logger.experiment.log(metrics_to_log, commit=False)
+     else:
+         for key, value in metrics_to_log.items():
+             logger.log_scalar(key.replace("/", "_"), value, step=step)
+
+     return metrics_to_log
+
+
+ def log_evaluation(
+     logger: WandbLogger,
+     rollouts: TensorDictBase,
+     env_test: VmasEnv,
+     evaluation_time: float,
+     step: int,
+ ):
+     rollouts = list(rollouts.unbind(0))
+     for k, r in enumerate(rollouts):
+         next_done = r.get(("next", "done")).sum(
+             tuple(range(r.batch_dims, r.get(("next", "done")).ndim)),
+             dtype=torch.bool,
+         )
+         done_index = next_done.nonzero(as_tuple=True)[0][
+             0
+         ]  # First done index for this traj
+         rollouts[k] = r[: done_index + 1]
+
+     rewards = [td.get(("next", "agents", "reward")).sum(0).mean() for td in rollouts]
+     metrics_to_log = {
+         "eval/episode_reward_min": min(rewards),
+         "eval/episode_reward_max": max(rewards),
+         "eval/episode_reward_mean": sum(rewards) / len(rollouts),
+         "eval/episode_len_mean": sum([td.batch_size[0] for td in rollouts])
+         / len(rollouts),
+         "eval/evaluation_time": evaluation_time,
+     }
+
+     vid = torch.tensor(
+         np.transpose(env_test.frames[: rollouts[0].batch_size[0]], (0, 3, 1, 2)),
+         dtype=torch.uint8,
+     ).unsqueeze(0)
+
+     if isinstance(logger, WandbLogger):
+         import wandb
+
+         logger.experiment.log(metrics_to_log, commit=False)
+         logger.experiment.log(
+             {
+                 "eval/video": wandb.Video(vid, fps=1 / env_test.world.dt, format="mp4"),
+             },
+             commit=False,
+         )
+     else:
+         for key, value in metrics_to_log.items():
+             logger.log_scalar(key.replace("/", "_"), value, step=step)
+         logger.log_video("eval_video", vid, step=step)
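Editor's note (not part of the package): in `log_training` above, completed-episode returns are selected by averaging `episode_reward` over the agent dimension and then indexing with the boolean `done` mask. A small, self-contained shape walk-through with made-up values:

```python
import torch

frames, n_agents = 6, 2
episode_reward = torch.arange(frames * n_agents, dtype=torch.float).reshape(
    frames, n_agents, 1
)
done = torch.tensor([[False], [True], [False], [False], [True], [False]])  # (frames, 1)

per_frame_return = episode_reward.mean(-2)  # (frames, 1): mean over the agent dim
completed = per_frame_return[done]          # 1-D, one value per frame where done is True
print(completed.shape)                      # torch.Size([2])
```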
sota-implementations/multiagent/utils/utils.py
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from tensordict import unravel_key
+ from torchrl.envs import Transform
+
+
+ def swap_last(source, dest):
+     source = unravel_key(source)
+     dest = unravel_key(dest)
+     if isinstance(source, str):
+         if isinstance(dest, str):
+             return dest
+         return dest[-1]
+     if isinstance(dest, str):
+         return source[:-1] + (dest,)
+     return source[:-1] + (dest[-1],)
+
+
+ class DoneTransform(Transform):
+     """Expands the 'done' entries (incl. terminated) to match the reward shape.
+
+     Can be appended to a replay buffer or a collector.
+     """
+
+     def __init__(self, reward_key, done_keys):
+         super().__init__()
+         self.reward_key = reward_key
+         self.done_keys = done_keys
+
+     def forward(self, tensordict):
+         for done_key in self.done_keys:
+             new_name = swap_last(self.reward_key, done_key)
+             tensordict.set(
+                 ("next", new_name),
+                 tensordict.get(("next", done_key))
+                 .unsqueeze(-1)
+                 .expand(tensordict.get(("next", self.reward_key)).shape),
+             )
+         return tensordict
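Editor's note (not part of the package): `swap_last` keeps the prefix of the reward key and swaps in the last element of the done key, which is how `DoneTransform` decides where to write the expanded done/terminated entries. A short illustration, assuming `swap_last` from the file above is in scope:

```python
# Expected key mappings for swap_last as defined above.
assert swap_last(("agents", "reward"), "done") == ("agents", "done")
assert swap_last(("agents", "reward"), ("next", "terminated")) == ("agents", "terminated")
assert swap_last("reward", "done") == "done"
# DoneTransform then writes, e.g., ("next", "agents", "done"), broadcast via
# unsqueeze(-1).expand(...) to the shape of ("next", "agents", "reward").
```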
sota-implementations/ppo/README.md
@@ -0,0 +1,29 @@
+ ## Reproducing Proximal Policy Optimization (PPO) Algorithm Results
+
+ This repository contains scripts for training agents with the Proximal Policy Optimization (PPO) algorithm on MuJoCo and Atari environments. We follow the original paper, [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017), with one improvement: the Generalised Advantage Estimator (GAE) is recomputed at every epoch.
+
+
+ ## Examples Structure
+
+ Please note that, for the sake of simplicity, each example is independent of the others. Each example contains the following files:
+
+ 1. **Main Script:** The algorithm components and the training loop are defined in the main script (e.g. ppo_atari.py).
+
+ 2. **Utils File:** A utility file provides helper functions, mainly for creating the environment and the models (e.g. utils_atari.py).
+
+ 3. **Configuration File:** This file contains the default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml).
+
+
+ ## Running the Examples
+
+ You can run the PPO algorithm on Atari environments with the following command:
+
+ ```bash
+ python ppo_atari.py
+ ```
+
+ You can run the PPO algorithm on MuJoCo environments with the following command:
+
+ ```bash
+ python ppo_mujoco.py
+ ```
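Editor's note (not part of the package): the README above mentions recomputing the Generalised Advantage Estimator (GAE) at every epoch; torchrl ships its own GAE implementation under `torchrl.objectives.value` (see `advantages.py` in the file list). For reference, a minimal, self-contained sketch of the GAE recursion on plain tensors, with made-up data:

```python
import torch


def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    # A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    advantages = torch.zeros_like(rewards)
    running = torch.zeros(())
    for t in reversed(range(rewards.shape[0])):
        not_done = (~dones[t]).float()
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


T = 8
adv = gae(torch.randn(T), torch.randn(T), torch.randn(T), torch.zeros(T, dtype=torch.bool))
print(adv.shape)  # torch.Size([8])
# Recomputing GAE at every epoch means refreshing V(s) and V(s') with the
# latest value-network weights before each pass over the batch.
```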