torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
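The wheel tag in the title (`cp314-cp314t-macosx_11_0_arm64`) pins this build to free-threaded CPython 3.14 on Apple Silicon macOS 11+. A quick install-and-import sanity check of the published artifact (the version pin comes from the metadata above; everything else is standard pip usage):

```bash
# Requires a free-threaded CPython 3.14 interpreter on an arm64 Mac,
# matching the wheel tag above.
pip install torchrl==0.11.0
python -c "import torchrl; print(torchrl.__version__)"
```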
sota-implementations/iql/utils.py
@@ -0,0 +1,437 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+
+ import torch.nn
+ import torch.optim
+ from tensordict.nn import InteractionType, TensorDictModule
+ from tensordict.nn.distributions import NormalParamExtractor
+ from torch.distributions import Categorical
+
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import (
+     Composite,
+     LazyMemmapStorage,
+     TensorDictPrioritizedReplayBuffer,
+     TensorDictReplayBuffer,
+ )
+ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+ from torchrl.data.replay_buffers import SamplerWithoutReplacement
+ from torchrl.envs import (
+     CatTensors,
+     Compose,
+     DMControlEnv,
+     DoubleToFloat,
+     EnvCreator,
+     InitTracker,
+     ParallelEnv,
+     RewardSum,
+     TransformedEnv,
+ )
+
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import (
+     MLP,
+     ProbabilisticActor,
+     SafeModule,
+     TanhNormal,
+     ValueOperator,
+ )
+ from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate
+ from torchrl.record import VideoRecorder
+ from torchrl.trainers.helpers.models import ACTIVATIONS
+
+
+ # ====================================================================
+ # Environment utils
+ # -----------------
+
+
+ def env_maker(cfg, device="cpu", from_pixels=False):
+     lib = cfg.env.backend
+     if lib in ("gym", "gymnasium"):
+         with set_gym_backend(lib):
+             return GymEnv(
+                 cfg.env.name,
+                 device=device,
+                 from_pixels=from_pixels,
+                 pixels_only=False,
+                 categorical_action_encoding=True,
+             )
+     elif lib == "dm_control":
+         env = DMControlEnv(
+             cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+         )
+         return TransformedEnv(
+             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+         )
+     else:
+         raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+ def apply_env_transforms(
+     env,
+ ):
+     transformed_env = TransformedEnv(
+         env,
+         Compose(
+             InitTracker(),
+             DoubleToFloat(),
+             RewardSum(),
+         ),
+     )
+     return transformed_env
+
+
+ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
+     """Make environments for training and evaluation."""
+     maker = functools.partial(env_maker, cfg)
+     parallel_env = ParallelEnv(
+         train_num_envs,
+         EnvCreator(maker),
+         serial_for_single=True,
+     )
+     parallel_env.set_seed(cfg.env.seed)
+
+     train_env = apply_env_transforms(parallel_env)
+
+     maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+     eval_env = TransformedEnv(
+         ParallelEnv(
+             eval_num_envs,
+             EnvCreator(maker),
+             serial_for_single=True,
+         ),
+         train_env.transform.clone(),
+     )
+     if cfg.logger.video:
+         eval_env.insert_transform(
+             0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+         )
+     return train_env, eval_env
+
+
+ # ====================================================================
+ # Collector and replay buffer
+ # ---------------------------
+
+
+ def make_collector(cfg, train_env, actor_model_explore, compile_mode):
+     """Make collector."""
+     device = cfg.collector.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             device = torch.device("cuda:0")
+         else:
+             device = torch.device("cpu")
+     collector = SyncDataCollector(
+         train_env,
+         actor_model_explore,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         init_random_frames=cfg.collector.init_random_frames,
+         max_frames_per_traj=cfg.collector.max_frames_per_traj,
+         total_frames=cfg.collector.total_frames,
+         device=device,
+         compile_policy={"mode": compile_mode} if compile_mode else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+     collector.set_seed(cfg.env.seed)
+     return collector
+
+
+ def make_replay_buffer(
+     batch_size,
+     prb=False,
+     buffer_size=1000000,
+     scratch_dir=None,
+     device="cpu",
+     prefetch=3,
+ ):
+     if prb:
+         replay_buffer = TensorDictPrioritizedReplayBuffer(
+             alpha=0.7,
+             beta=0.5,
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=LazyMemmapStorage(
+                 buffer_size,
+                 scratch_dir=scratch_dir,
+                 device=device,
+             ),
+             batch_size=batch_size,
+         )
+     else:
+         replay_buffer = TensorDictReplayBuffer(
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=LazyMemmapStorage(
+                 buffer_size,
+                 scratch_dir=scratch_dir,
+                 device=device,
+             ),
+             batch_size=batch_size,
+         )
+     return replay_buffer
+
+
+ def make_offline_replay_buffer(rb_cfg):
+     data = D4RLExperienceReplay(
+         dataset_id=rb_cfg.dataset,
+         split_trajs=False,
+         batch_size=rb_cfg.batch_size,
+         # We use drop_last to avoid recompiles (and dynamic shapes)
+         sampler=SamplerWithoutReplacement(drop_last=True),
+         prefetch=4,
+         direct_download=True,
+     )
+
+     data.append_transform(DoubleToFloat())
+
+     return data
+
+
+ # ====================================================================
+ # Model
+ # -----
+ #
+ # We give one version of the model for learning from pixels, and one for state.
+ # TorchRL comes in handy at this point, as the high-level interaction with
+ # these models is unchanged, regardless of the modality.
+ #
+
+
+ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
+     model_cfg = cfg.model
+
+     in_keys = ["observation"]
+     action_spec = train_env.action_spec_unbatched
+     actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)
+
+     out_keys = ["loc", "scale"]
+
+     actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys)
+
+     # We use a ProbabilisticActor to make sure that we map the
+     # network output to the right space using a TanhNormal
+     # distribution.
+     actor = ProbabilisticActor(
+         module=actor_module,
+         in_keys=["loc", "scale"],
+         spec=action_spec,
+         distribution_class=TanhNormal,
+         distribution_kwargs={
+             "low": action_spec.space.low.to(device),
+             "high": action_spec.space.high.to(device),
+             "tanh_loc": False,
+         },
+         default_interaction_type=ExplorationType.RANDOM,
+     )
+
+     in_keys = ["observation", "action"]
+
+     out_keys = ["state_action_value"]
+     qvalue = ValueOperator(
+         in_keys=in_keys,
+         out_keys=out_keys,
+         module=q_net,
+     )
+     in_keys = ["observation"]
+     out_keys = ["state_value"]
+     value_net = ValueOperator(
+         in_keys=in_keys,
+         out_keys=out_keys,
+         module=value_net,
+     )
+     model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device)
+     # init nets
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         td = eval_env.fake_tensordict()
+         td = td.to(device)
+         for net in model:
+             net(td)
+
+     return model
+
+
+ def make_iql_modules_state(model_cfg, proof_environment):
+     action_spec = proof_environment.action_spec_unbatched
+
+     actor_net_kwargs = {
+         "num_cells": model_cfg.hidden_sizes,
+         "out_features": 2 * action_spec.shape[-1],
+         "activation_class": ACTIVATIONS[model_cfg.activation],
+     }
+     actor_net = MLP(**actor_net_kwargs)
+     actor_extractor = NormalParamExtractor(
+         scale_mapping=f"biased_softplus_{model_cfg.default_policy_scale}",
+         scale_lb=model_cfg.scale_lb,
+     )
+     actor_net = torch.nn.Sequential(actor_net, actor_extractor)
+
+     qvalue_net_kwargs = {
+         "num_cells": model_cfg.hidden_sizes,
+         "out_features": 1,
+         "activation_class": ACTIVATIONS[model_cfg.activation],
+     }
+
+     q_net = MLP(**qvalue_net_kwargs)
+
+     # Define Value Network
+     value_net_kwargs = {
+         "num_cells": model_cfg.hidden_sizes,
+         "out_features": 1,
+         "activation_class": ACTIVATIONS[model_cfg.activation],
+     }
+     value_net = MLP(**value_net_kwargs)
+
+     return actor_net, q_net, value_net
+
+
+ def make_discrete_iql_model(cfg, train_env, eval_env, device):
+     """Make discrete IQL agent."""
+     # Define Actor Network
+     in_keys = ["observation"]
+     action_spec = train_env.action_spec_unbatched
+
+     actor_net = MLP(
+         num_cells=cfg.model.hidden_sizes,
+         out_features=action_spec.space.n,
+         activation_class=ACTIVATIONS[cfg.model.activation],
+         device=device,
+     )
+
+     actor_module = SafeModule(
+         module=actor_net,
+         in_keys=in_keys,
+         out_keys=["logits"],
+     )
+     actor = ProbabilisticActor(
+         spec=Composite(action=eval_env.action_spec_unbatched).to(device),
+         module=actor_module,
+         in_keys=["logits"],
+         out_keys=["action"],
+         distribution_class=Categorical,
+         distribution_kwargs={},
+         default_interaction_type=InteractionType.RANDOM,
+         return_log_prob=False,
+     )
+
+     # Define Critic Network
+     qvalue_net = MLP(
+         num_cells=cfg.model.hidden_sizes,
+         out_features=action_spec.space.n,
+         activation_class=ACTIVATIONS[cfg.model.activation],
+         device=device,
+     )
+     qvalue = TensorDictModule(
+         in_keys=["observation"],
+         out_keys=["state_action_value"],
+         module=qvalue_net,
+     )
+
+     # Define Value Network
+     value_net = MLP(
+         num_cells=cfg.model.hidden_sizes,
+         out_features=1,
+         activation_class=ACTIVATIONS[cfg.model.activation],
+         device=device,
+     )
+     value_net = TensorDictModule(
+         in_keys=["observation"],
+         out_keys=["state_value"],
+         module=value_net,
+     )
+
+     model = torch.nn.ModuleList([actor, qvalue, value_net])
+     # init nets
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         td = eval_env.fake_tensordict()
+         td = td.to(device)
+         for net in model:
+             net(td)
+
+     return model
+
+
+ # ====================================================================
+ # IQL Loss
+ # --------
+
+
+ def make_loss(loss_cfg, model, device):
+     loss_module = IQLLoss(
+         model[0],
+         model[1],
+         value_network=model[2],
+         loss_function=loss_cfg.loss_function,
+         temperature=loss_cfg.temperature,
+         expectile=loss_cfg.expectile,
+     )
+     loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
+     target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
+
+     return loss_module, target_net_updater
+
+
+ def make_discrete_loss(loss_cfg, model, device):
+     loss_module = DiscreteIQLLoss(
+         model[0],
+         model[1],
+         value_network=model[2],
+         loss_function=loss_cfg.loss_function,
+         temperature=loss_cfg.temperature,
+         expectile=loss_cfg.expectile,
+         action_space="categorical",
+     )
+     loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
+     target_net_updater = HardUpdate(
+         loss_module, value_network_update_interval=loss_cfg.hard_update_interval
+     )
+
+     return loss_module, target_net_updater
+
+
+ def make_iql_optimizer(optim_cfg, loss_module):
+     critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+     actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+     value_params = list(loss_module.value_network_params.flatten_keys().values())
+
+     optimizer_actor = torch.optim.Adam(
+         actor_params,
+         lr=optim_cfg.lr,
+         weight_decay=optim_cfg.weight_decay,
+     )
+     optimizer_critic = torch.optim.Adam(
+         critic_params,
+         lr=optim_cfg.lr,
+         weight_decay=optim_cfg.weight_decay,
+     )
+     optimizer_value = torch.optim.Adam(
+         value_params,
+         lr=optim_cfg.lr,
+         weight_decay=optim_cfg.weight_decay,
+     )
+     return optimizer_actor, optimizer_critic, optimizer_value
+
+
+ # ====================================================================
+ # General utils
+ # -------------
+
+
+ def log_metrics(logger, metrics, step):
+     if logger is not None:
+         for metric_name, metric_value in metrics.items():
+             logger.log_scalar(metric_name, metric_value, step)
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
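The helpers in this file (the IQL utilities, sota-implementations/iql/utils.py per the listing above) are meant to be composed by the training scripts. A minimal sketch of the call order, assuming an omegaconf config object (hydra-core, which these scripts use, depends on it); every value below is a hypothetical placeholder, and only the key layout follows the functions in the diff:

```python
from omegaconf import OmegaConf

# Hypothetical config values; only the key paths mirror what the helpers read.
cfg = OmegaConf.create(
    {
        "env": {"backend": "gymnasium", "name": "HalfCheetah-v4", "seed": 0},
        "logger": {"video": False},
        "collector": {
            "device": "cpu",
            "frames_per_batch": 1000,
            "init_random_frames": 5000,
            "max_frames_per_traj": 1000,
            "total_frames": 100_000,
        },
        "compile": {"cudagraphs": False},
        "model": {
            "hidden_sizes": [256, 256],
            "activation": "relu",
            "default_policy_scale": 1.0,
            "scale_lb": 0.1,
        },
        "loss": {
            "loss_function": "l2",
            "temperature": 3.0,
            "expectile": 0.7,
            "gamma": 0.99,
            "tau": 0.005,
        },
        "optim": {"lr": 3e-4, "weight_decay": 0.0},
    }
)

# Build envs, model, loss, collector and optimizers with the helpers above.
train_env, eval_env = make_environment(cfg)
model = make_iql_model(cfg, train_env, eval_env, device="cpu")
loss_module, target_updater = make_loss(cfg.loss, model, device="cpu")
collector = make_collector(cfg, train_env, model[0], compile_mode=None)
actor_opt, critic_opt, value_opt = make_iql_optimizer(cfg.optim, loss_module)
```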
sota-implementations/multiagent/README.md
@@ -0,0 +1,74 @@
+ # Multi-agent examples
+
+ In this folder we provide a set of multi-agent example scripts using the [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) simulator.
+
+ <p align="center">
+ <img src="https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/img/marl_vmas.png" width="600px">
+ </p>
+
+ <center><i>The MARL algorithms contained in the scripts of this folder run on three multi-robot tasks in VMAS.</i></center>
+
+ For more details on the experiment setup and the environments, please refer to the corresponding section of the appendix in the [TorchRL paper](https://arxiv.org/abs/2306.00577).
+
+ > [!NOTE]
+ > If you are interested in Multi-Agent Reinforcement Learning (MARL) in TorchRL,
+ > check out [BenchMARL](https://github.com/facebookresearch/BenchMARL): a benchmarking
+ > library where you can train and compare MARL algorithms, tasks, and models using TorchRL!
+
+ ## Using the scripts
+
+ ### Install
+
+ First you need to install VMAS and the scripts' dependencies.
+
+ Install torchrl and tensordict following the repo instructions, then install vmas and the remaining dependencies:
+
+ ```bash
+ pip install vmas
+ pip install wandb "moviepy<2.0.0"
+ pip install hydra-core
+ ```
+
+ ### Run
+
+ To run a script, execute the corresponding Python file after modifying its config according to your needs.
+ The config can be found in the .yaml file with the same name.
+
+ For example:
+ ```bash
+ python mappo_ippo.py
+ ```
+
+ You can also override the config from the command line, for example:
+
+ ```bash
+ python mappo_ippo.py env.scenario_name=navigation
+ ```
+
+ ### Computational demand
+ The scripts are set up to collect a large number of frames. If your compute is limited, you can reduce the "frames_per_batch"
+ and "num_epochs" parameters to lower the compute requirements (see the override example after this README).
+
+ ### Script structure
+
+ The scripts are self-contained: all the code you need to read is in the script file itself, and no helper functions are used.
+
+ The structure of the scripts follows this order:
+ - Configuration dictionary for the script
+ - Environment creation
+ - Modules creation
+ - Collector instantiation
+ - Replay buffer instantiation
+ - Loss module creation
+ - Training loop (with inner minibatch loops)
+ - Evaluation run (at the desired frequency)
+
+ Logging is done to wandb by default.
+ The logging backend can be changed in the config files to one of "wandb", "tensorboard", "csv", "mlflow".
+
+ All the scripts follow the same on-policy training structure so that results can be compared across different algorithms.
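Combining the README's points above, an invocation that trims compute and swaps the logging backend could look like the following sketch. The group prefixes (`collector.`, `loss.`, `logger.`) are assumptions about the packaged .yaml layout; only the parameter names `frames_per_batch` and `num_epochs` and the backend list come from the README:

```bash
# Hypothetical config paths; check the script's .yaml for the actual group names.
python mappo_ippo.py collector.frames_per_batch=6000 loss.num_epochs=10 logger.backend=csv
```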