torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/multiagent/mappo_ippo.py
@@ -0,0 +1,267 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import time
+
+import hydra
+import torch
+from tensordict.nn import TensorDictModule
+from tensordict.nn.distributions import NormalParamExtractor
+from torch import nn
+from torchrl._utils import logger as torchrl_logger
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.envs import RewardSum, TransformedEnv
+from torchrl.envs.libs.vmas import VmasEnv
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
+from torchrl.modules.models.multiagent import MultiAgentMLP
+from torchrl.objectives import ClipPPOLoss, ValueEstimators
+from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform
+
+
+def rendering_callback(env, td):
+    env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="mappo_ippo")
+def train(cfg: DictConfig):  # noqa: F821
+    # Device
+    cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
+    cfg.env.device = cfg.train.device
+
+    # Seeding
+    torch.manual_seed(cfg.seed)
+
+    # Sampling
+    cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
+    cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
+    cfg.buffer.memory_size = cfg.collector.frames_per_batch
+
+    # Create env and env_test
+    env = VmasEnv(
+        scenario=cfg.env.scenario_name,
+        num_envs=cfg.env.vmas_envs,
+        continuous_actions=True,
+        max_steps=cfg.env.max_steps,
+        device=cfg.env.device,
+        seed=cfg.seed,
+        # Scenario kwargs
+        **cfg.env.scenario,
+    )
+    env = TransformedEnv(
+        env,
+        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
+    )
+
+    env_test = VmasEnv(
+        scenario=cfg.env.scenario_name,
+        num_envs=cfg.eval.evaluation_episodes,
+        continuous_actions=True,
+        max_steps=cfg.env.max_steps,
+        device=cfg.env.device,
+        seed=cfg.seed,
+        # Scenario kwargs
+        **cfg.env.scenario,
+    )
+
+    # Policy
+    actor_net = nn.Sequential(
+        MultiAgentMLP(
+            n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
+            n_agent_outputs=2
+            * env.full_action_spec_unbatched[env.action_key].shape[-1],
+            n_agents=env.n_agents,
+            centralised=False,
+            share_params=cfg.model.shared_parameters,
+            device=cfg.train.device,
+            depth=2,
+            num_cells=256,
+            activation_class=nn.Tanh,
+        ),
+        NormalParamExtractor(),
+    )
+    policy_module = TensorDictModule(
+        actor_net,
+        in_keys=[("agents", "observation")],
+        out_keys=[("agents", "loc"), ("agents", "scale")],
+    )
+    policy = ProbabilisticActor(
+        module=policy_module,
+        spec=env.full_action_spec_unbatched,
+        in_keys=[("agents", "loc"), ("agents", "scale")],
+        out_keys=[env.action_key],
+        distribution_class=TanhNormal,
+        distribution_kwargs={
+            "low": env.full_action_spec_unbatched[("agents", "action")].space.low,
+            "high": env.full_action_spec_unbatched[("agents", "action")].space.high,
+        },
+        return_log_prob=True,
+    )
+
+    # Critic
+    module = MultiAgentMLP(
+        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
+        n_agent_outputs=1,
+        n_agents=env.n_agents,
+        centralised=cfg.model.centralised_critic,
+        share_params=cfg.model.shared_parameters,
+        device=cfg.train.device,
+        depth=2,
+        num_cells=256,
+        activation_class=nn.Tanh,
+    )
+    value_module = ValueOperator(
+        module=module,
+        in_keys=[("agents", "observation")],
+    )
+
+    collector = SyncDataCollector(
+        env,
+        policy,
+        device=cfg.env.device,
+        storing_device=cfg.train.device,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
+    )
+
+    replay_buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=cfg.train.minibatch_size,
+    )
+
+    # Loss
+    loss_module = ClipPPOLoss(
+        actor_network=policy,
+        critic_network=value_module,
+        clip_epsilon=cfg.loss.clip_epsilon,
+        entropy_coeff=cfg.loss.entropy_eps,
+        normalize_advantage=False,
+    )
+    loss_module.set_keys(
+        reward=env.reward_key,
+        action=env.action_key,
+        done=("agents", "done"),
+        terminated=("agents", "terminated"),
+    )
+    loss_module.make_value_estimator(
+        ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda
+    )
+    optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr)
+
+    # Logging
+    if cfg.logger.backend:
+        model_name = (
+            ("Het" if not cfg.model.shared_parameters else "")
+            + ("MA" if cfg.model.centralised_critic else "I")
+            + "PPO"
+        )
+        logger = init_logging(cfg, model_name)
+
+    total_time = 0
+    total_frames = 0
+    sampling_start = time.time()
+    for i, tensordict_data in enumerate(collector):
+        torchrl_logger.info(f"\nIteration {i}")
+
+        sampling_time = time.time() - sampling_start
+
+        with torch.no_grad():
+            loss_module.value_estimator(
+                tensordict_data,
+                params=loss_module.critic_network_params,
+                target_params=loss_module.target_critic_network_params,
+            )
+        current_frames = tensordict_data.numel()
+        total_frames += current_frames
+        data_view = tensordict_data.reshape(-1)
+        replay_buffer.extend(data_view)
+
+        training_tds = []
+        training_start = time.time()
+        for _ in range(cfg.train.num_epochs):
+            for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
+                subdata = replay_buffer.sample()
+                loss_vals = loss_module(subdata)
+                training_tds.append(loss_vals.detach())
+
+                loss_value = (
+                    loss_vals["loss_objective"]
+                    + loss_vals["loss_critic"]
+                    + loss_vals["loss_entropy"]
+                )
+
+                loss_value.backward()
+
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    loss_module.parameters(), cfg.train.max_grad_norm
+                )
+                training_tds[-1].set("grad_norm", total_norm.mean())
+
+                optim.step()
+                optim.zero_grad()
+
+        collector.update_policy_weights_()
+
+        training_time = time.time() - training_start
+
+        iteration_time = sampling_time + training_time
+        total_time += iteration_time
+        training_tds = torch.stack(training_tds)
+
+        # More logs
+        if cfg.logger.backend:
+            log_training(
+                logger,
+                training_tds,
+                tensordict_data,
+                sampling_time,
+                training_time,
+                total_time,
+                i,
+                current_frames,
+                total_frames,
+                step=i,
+            )
+
+        if (
+            cfg.eval.evaluation_episodes > 0
+            and i % cfg.eval.evaluation_interval == 0
+            and cfg.logger.backend
+        ):
+            evaluation_start = time.time()
+            with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+                env_test.frames = []
+                rollouts = env_test.rollout(
+                    max_steps=cfg.env.max_steps,
+                    policy=policy,
+                    callback=rendering_callback,
+                    auto_cast_to_device=True,
+                    break_when_any_done=False,
+                    # We are running vectorized evaluation we do not want it to stop when just one env is done
+                )
+
+                evaluation_time = time.time() - evaluation_start
+
+                log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)
+
+        if cfg.logger.backend == "wandb":
+            logger.experiment.log({}, commit=True)
+        sampling_start = time.time()
+    collector.shutdown()
+    if not env.is_closed:
+        env.close()
+    if not env_test.is_closed:
+        env_test.close()
+
+
+if __name__ == "__main__":
+    train()
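Both multi-agent scripts in this release resolve their hyperparameters through Hydra (config_path="", config_name=...), so a YAML file exposing every cfg.* field accessed in train() is expected next to the script. The sketch below is not the shipped mappo_ippo.yaml; it only enumerates, as an OmegaConf object, the fields the MAPPO/IPPO entry point above reads, and every value is an illustrative placeholder.

# A minimal sketch, assuming only the field names taken from the cfg.* accesses in
# train() above; all values are placeholders, not the packaged defaults.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "seed": 0,
        "env": {
            "scenario_name": "navigation",  # placeholder VMAS scenario
            "max_steps": 100,
            "scenario": {},  # extra kwargs forwarded to VmasEnv
            "vmas_envs": None,  # filled in at runtime: frames_per_batch // max_steps
            "device": None,  # filled in at runtime
        },
        "collector": {"frames_per_batch": 6000, "n_iters": 10, "total_frames": None},
        "buffer": {"memory_size": None},  # overwritten to frames_per_batch at runtime
        "train": {
            "num_epochs": 10,
            "minibatch_size": 600,
            "lr": 3e-4,
            "max_grad_norm": 1.0,
            "device": None,  # filled in at runtime
        },
        "loss": {"gamma": 0.99, "lmbda": 0.9, "clip_epsilon": 0.2, "entropy_eps": 1e-4},
        "model": {"shared_parameters": True, "centralised_critic": True},
        "eval": {"evaluation_episodes": 10, "evaluation_interval": 5},
        "logger": {"backend": None},  # falsy value disables logging
    }
)

Since the entry point is wrapped in @hydra.main, any of these fields can also be overridden on the command line when the script is launched.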
sota-implementations/multiagent/qmix_vdn.py
@@ -0,0 +1,271 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import time
+
+import hydra
+import torch
+from tensordict.nn import TensorDictModule, TensorDictSequential
+from torch import nn
+from torchrl._utils import logger as torchrl_logger
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.data.replay_buffers.storages import LazyTensorStorage
+from torchrl.envs import RewardSum, TransformedEnv
+from torchrl.envs.libs.vmas import VmasEnv
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import EGreedyModule, QValueModule, SafeSequential
+from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer
+from torchrl.objectives import SoftUpdate, ValueEstimators
+from torchrl.objectives.multiagent.qmixer import QMixerLoss
+from utils.logging import init_logging, log_evaluation, log_training
+
+
+def rendering_callback(env, td):
+    env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="qmix_vdn")
+def train(cfg: DictConfig):  # noqa: F821
+    # Device
+    cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
+    cfg.env.device = cfg.train.device
+
+    # Seeding
+    torch.manual_seed(cfg.seed)
+
+    # Sampling
+    cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
+    cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
+    cfg.buffer.memory_size = cfg.collector.frames_per_batch
+
+    # Create env and env_test
+    env = VmasEnv(
+        scenario=cfg.env.scenario_name,
+        num_envs=cfg.env.vmas_envs,
+        continuous_actions=False,
+        max_steps=cfg.env.max_steps,
+        device=cfg.env.device,
+        seed=cfg.seed,
+        # Scenario kwargs
+        **cfg.env.scenario,
+    )
+    env = TransformedEnv(
+        env,
+        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
+    )
+
+    env_test = VmasEnv(
+        scenario=cfg.env.scenario_name,
+        num_envs=cfg.eval.evaluation_episodes,
+        continuous_actions=False,
+        max_steps=cfg.env.max_steps,
+        device=cfg.env.device,
+        seed=cfg.seed,
+        # Scenario kwargs
+        **cfg.env.scenario,
+    )
+
+    # Policy
+    net = MultiAgentMLP(
+        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
+        n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
+        n_agents=env.n_agents,
+        centralised=False,
+        share_params=cfg.model.shared_parameters,
+        device=cfg.train.device,
+        depth=2,
+        num_cells=256,
+        activation_class=nn.Tanh,
+    )
+    module = TensorDictModule(
+        net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
+    )
+    value_module = QValueModule(
+        action_value_key=("agents", "action_value"),
+        out_keys=[
+            env.action_key,
+            ("agents", "action_value"),
+            ("agents", "chosen_action_value"),
+        ],
+        spec=env.full_action_spec_unbatched,
+        action_space=None,
+    )
+    qnet = SafeSequential(module, value_module)
+
+    qnet_explore = TensorDictSequential(
+        qnet,
+        EGreedyModule(
+            eps_init=0.3,
+            eps_end=0,
+            annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
+            action_key=env.action_key,
+            spec=env.full_action_spec_unbatched,
+        ),
+    )
+
+    if cfg.loss.mixer_type == "qmix":
+        mixer = TensorDictModule(
+            module=QMixer(
+                state_shape=env.observation_spec_unbatched[
+                    "agents", "observation"
+                ].shape,
+                mixing_embed_dim=32,
+                n_agents=env.n_agents,
+                device=cfg.train.device,
+            ),
+            in_keys=[("agents", "chosen_action_value"), ("agents", "observation")],
+            out_keys=["chosen_action_value"],
+        )
+    elif cfg.loss.mixer_type == "vdn":
+        mixer = TensorDictModule(
+            module=VDNMixer(
+                n_agents=env.n_agents,
+                device=cfg.train.device,
+            ),
+            in_keys=[("agents", "chosen_action_value")],
+            out_keys=["chosen_action_value"],
+        )
+    else:
+        raise ValueError("Mixer type not in the example")
+
+    collector = SyncDataCollector(
+        env,
+        qnet_explore,
+        device=cfg.env.device,
+        storing_device=cfg.train.device,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
+    )
+
+    replay_buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=cfg.train.minibatch_size,
+    )
+
+    loss_module = QMixerLoss(qnet, mixer, delay_value=True)
+    loss_module.set_keys(
+        action_value=("agents", "action_value"),
+        local_value=("agents", "chosen_action_value"),
+        global_value="chosen_action_value",
+        action=env.action_key,
+    )
+    loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
+    target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
+
+    optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr)
+
+    # Logging
+    if cfg.logger.backend:
+        model_name = (
+            "Het" if not cfg.model.shared_parameters else ""
+        ) + cfg.loss.mixer_type.upper()
+        logger = init_logging(cfg, model_name)
+
+    total_time = 0
+    total_frames = 0
+    sampling_start = time.time()
+    for i, tensordict_data in enumerate(collector):
+        torchrl_logger.info(f"\nIteration {i}")
+
+        sampling_time = time.time() - sampling_start
+
+        # Remove agent dimension from reward (since it is shared in QMIX/VDN)
+        tensordict_data.set(
+            ("next", "reward"), tensordict_data.get(("next", env.reward_key)).mean(-2)
+        )
+        del tensordict_data["next", env.reward_key]
+        tensordict_data.set(
+            ("next", "episode_reward"),
+            tensordict_data.get(("next", "agents", "episode_reward")).mean(-2),
+        )
+        del tensordict_data["next", "agents", "episode_reward"]
+
+        current_frames = tensordict_data.numel()
+        total_frames += current_frames
+        data_view = tensordict_data.reshape(-1)
+        replay_buffer.extend(data_view)
+
+        training_tds = []
+        training_start = time.time()
+        for _ in range(cfg.train.num_epochs):
+            for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
+                subdata = replay_buffer.sample()
+                loss_vals = loss_module(subdata)
+                training_tds.append(loss_vals.detach())
+
+                loss_value = loss_vals["loss"]
+
+                loss_value.backward()
+
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    loss_module.parameters(), cfg.train.max_grad_norm
+                )
+                training_tds[-1].set("grad_norm", total_norm.mean())
+
+                optim.step()
+                optim.zero_grad()
+                target_net_updater.step()
+
+        qnet_explore[1].step(frames=current_frames)  # Update exploration annealing
+        collector.update_policy_weights_()
+
+        training_time = time.time() - training_start
+
+        iteration_time = sampling_time + training_time
+        total_time += iteration_time
+        training_tds = torch.stack(training_tds)
+
+        # More logs
+        if cfg.logger.backend:
+            log_training(
+                logger,
+                training_tds,
+                tensordict_data,
+                sampling_time,
+                training_time,
+                total_time,
+                i,
+                current_frames,
+                total_frames,
+                step=i,
+            )
+
+        if (
+            cfg.eval.evaluation_episodes > 0
+            and i % cfg.eval.evaluation_interval == 0
+            and cfg.logger.backend
+        ):
+            evaluation_start = time.time()
+            with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+                env_test.frames = []
+                rollouts = env_test.rollout(
+                    max_steps=cfg.env.max_steps,
+                    policy=qnet,
+                    callback=rendering_callback,
+                    auto_cast_to_device=True,
+                    break_when_any_done=False,
+                    # We are running vectorized evaluation we do not want it to stop when just one env is done
+                )
+
+                evaluation_time = time.time() - evaluation_start
+
+                log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)
+
+        if cfg.logger.backend == "wandb":
+            logger.experiment.log({}, commit=True)
+        sampling_start = time.time()
+    collector.shutdown()
+    if not env.is_closed:
+        env.close()
+    if not env_test.is_closed:
+        env_test.close()
+
+
+if __name__ == "__main__":
+    train()
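A note on the mixer wiring in the QMIX/VDN script above: the TensorDictModule wrappers feed the per-agent ("agents", "chosen_action_value") entries (plus, for QMIX, the per-agent observations acting as state) into the mixer and write a single root-level "chosen_action_value", which QMixerLoss regresses with the TD(0) estimator. The sketch below is not part of the package diff; it only illustrates the tensor shapes implied by that wiring.

# A shape sketch, assuming the mixers collapse per-agent chosen Q-values of shape
# [*batch, n_agents, 1] into one global value of shape [*batch, 1].
import torch

from torchrl.modules.models.multiagent import QMixer, VDNMixer

batch, n_agents, obs_dim = 32, 3, 8
chosen_q = torch.randn(batch, n_agents, 1)  # ("agents", "chosen_action_value")
state = torch.randn(batch, n_agents, obs_dim)  # ("agents", "observation")

vdn = VDNMixer(n_agents=n_agents, device="cpu")
qmix = QMixer(
    state_shape=(n_agents, obs_dim),
    mixing_embed_dim=32,
    n_agents=n_agents,
    device="cpu",
)

print(vdn(chosen_q).shape)  # expected: torch.Size([32, 1]), a plain sum over agents
print(qmix(chosen_q, state).shape)  # expected: torch.Size([32, 1]), state-conditioned monotonic mixing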