torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/sac/utils.py
@@ -0,0 +1,381 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+
+ import torch
+ from tensordict.nn import InteractionType, TensorDictModule
+ from tensordict.nn.distributions import NormalParamExtractor
+ from torch import nn, optim
+ from torchrl.collectors import aSyncDataCollector, SyncDataCollector
+ from torchrl.data import (
+     LazyMemmapStorage,
+     LazyTensorStorage,
+     TensorDictPrioritizedReplayBuffer,
+     TensorDictReplayBuffer,
+ )
+ from torchrl.envs import (
+     CatTensors,
+     Compose,
+     DMControlEnv,
+     DoubleToFloat,
+     EnvCreator,
+     ParallelEnv,
+     TransformedEnv,
+ )
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+ from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
+ from torchrl.modules.distributions import TanhNormal
+ from torchrl.objectives import SoftUpdate
+ from torchrl.objectives.sac import SACLoss
+ from torchrl.record import VideoRecorder
+
+ # ====================================================================
+ # Environment utils
+ # -----------------
+
+
+ def env_maker(cfg, device="cpu", from_pixels=False):
+     lib = cfg.env.library
+     if lib in ("gym", "gymnasium"):
+         with set_gym_backend(lib):
+             return GymEnv(
+                 cfg.env.name,
+                 device=device,
+                 from_pixels=from_pixels,
+                 pixels_only=False,
+             )
+     elif lib == "dm_control":
+         env = DMControlEnv(
+             cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+         )
+         return TransformedEnv(
+             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+         )
+     else:
+         raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+ def apply_env_transforms(env, max_episode_steps=1000):
+     transformed_env = TransformedEnv(
+         env,
+         Compose(
+             InitTracker(),
+             StepCounter(max_episode_steps),
+             DoubleToFloat(),
+             RewardSum(),
+         ),
+     )
+     return transformed_env
+
+
+ def make_environment(cfg, logger=None):
+     """Make environments for training and evaluation."""
+     partial = functools.partial(env_maker, cfg=cfg)
+     parallel_env = ParallelEnv(
+         cfg.collector.env_per_collector,
+         EnvCreator(partial),
+         serial_for_single=True,
+     )
+     parallel_env.set_seed(cfg.env.seed)
+
+     train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)
+
+     partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video)
+     trsf_clone = train_env.transform.clone()
+     if cfg.logger.video:
+         trsf_clone.insert(
+             0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+         )
+     eval_env = TransformedEnv(
+         ParallelEnv(
+             cfg.collector.env_per_collector,
+             EnvCreator(partial),
+             serial_for_single=True,
+         ),
+         trsf_clone,
+     )
+     return train_env, eval_env
+
+
+ def make_train_environment(cfg):
+     """Make the training environment."""
+     partial = functools.partial(env_maker, cfg=cfg)
+     parallel_env = ParallelEnv(
+         cfg.collector.env_per_collector,
+         EnvCreator(partial),
+         serial_for_single=True,
+     )
+     parallel_env.set_seed(cfg.env.seed)
+
+     train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)
+
+     return train_env
+
+
+ # ====================================================================
+ # Collector and replay buffer
+ # ---------------------------
+
+
+ def make_collector(cfg, train_env, actor_model_explore, compile_mode):
+     """Make collector."""
+     device = cfg.collector.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             device = torch.device("cuda:0")
+         else:
+             device = torch.device("cpu")
+     collector = SyncDataCollector(
+         train_env,
+         actor_model_explore,
+         init_random_frames=cfg.collector.init_random_frames,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         device=device,
+         compile_policy={"mode": compile_mode} if compile_mode else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+     collector.set_seed(cfg.env.seed)
+     return collector
+
+
+ def flatten(td):
+     return td.reshape(-1)
+
+
+ def make_collector_async(
+     cfg, train_env_make, actor_model_explore, compile_mode, replay_buffer
+ ):
+     """Make async collector."""
+     device = cfg.collector.device
+     if device in ("", None):
+         if torch.cuda.is_available():
+             if torch.cuda.device_count() < 2:
+                 raise RuntimeError("Requires >= 2 GPUs")
+             device = torch.device("cuda:1")
+         else:
+             device = torch.device("cpu")
+
+     collector = aSyncDataCollector(
+         train_env_make,
+         actor_model_explore,
+         init_random_frames=0,  # Currently not supported, but accounted for in script: cfg.collector.init_random_frames,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         device=device,
+         env_device=torch.device("cpu"),
+         compile_policy={"mode": compile_mode, "warmup": 5} if compile_mode else False,
+         cudagraph_policy={"warmup": 20} if cfg.compile.cudagraphs else False,
+         replay_buffer=replay_buffer,
+         extend_buffer=True,
+         postproc=flatten,
+         no_cuda_sync=True,
+     )
+     collector.set_seed(cfg.env.seed)
+     collector.start()
+     return collector
+
+
+ def make_replay_buffer(
+     batch_size,
+     prb=False,
+     buffer_size=1000000,
+     scratch_dir=None,
+     device="cpu",
+     prefetch=3,
+     shared: bool = False,
+ ):
+     storage_cls = (
+         functools.partial(LazyTensorStorage, device=device)
+         if not scratch_dir
+         else functools.partial(LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir)
+     )
+     if prb:
+         replay_buffer = TensorDictPrioritizedReplayBuffer(
+             alpha=0.7,
+             beta=0.5,
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=storage_cls(
+                 buffer_size,
+             ),
+             batch_size=batch_size,
+             shared=shared,
+         )
+     else:
+         replay_buffer = TensorDictReplayBuffer(
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=storage_cls(
+                 buffer_size,
+             ),
+             batch_size=batch_size,
+             shared=shared,
+         )
+     if scratch_dir:
+         replay_buffer.append_transform(lambda td: td.to(device))
+     return replay_buffer
+
+
+ # ====================================================================
+ # Model
+ # -----
+
+
+ def make_sac_agent(cfg, train_env, eval_env, device):
+     """Make SAC agent."""
+     # Define Actor Network
+     in_keys = ["observation"]
+     action_spec = train_env.action_spec_unbatched.to(device)
+
+     actor_net = MLP(
+         num_cells=cfg.network.hidden_sizes,
+         out_features=2 * action_spec.shape[-1],
+         activation_class=get_activation(cfg),
+         device=device,
+     )
+
+     dist_class = TanhNormal
+     dist_kwargs = {
+         "low": action_spec.space.low,
+         "high": action_spec.space.high,
+         "tanh_loc": False,
+     }
+
+     actor_extractor = NormalParamExtractor(
+         scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}",
+         scale_lb=cfg.network.scale_lb,
+     ).to(device)
+     actor_net = nn.Sequential(actor_net, actor_extractor)
+
+     in_keys_actor = in_keys
+     actor_module = TensorDictModule(
+         actor_net,
+         in_keys=in_keys_actor,
+         out_keys=[
+             "loc",
+             "scale",
+         ],
+     )
+     actor = ProbabilisticActor(
+         spec=action_spec,
+         in_keys=["loc", "scale"],
+         module=actor_module,
+         distribution_class=dist_class,
+         distribution_kwargs=dist_kwargs,
+         default_interaction_type=InteractionType.RANDOM,
+         return_log_prob=False,
+     )
+
+     # Define Critic Network
+     qvalue_net = MLP(
+         num_cells=cfg.network.hidden_sizes,
+         out_features=1,
+         activation_class=get_activation(cfg),
+         device=device,
+     )
+
+     qvalue = ValueOperator(
+         in_keys=["action"] + in_keys,
+         module=qvalue_net,
+     )
+
+     model = nn.ModuleList([actor, qvalue])
+
+     # init nets
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         td = eval_env.fake_tensordict()
+         td = td.to(device)
+         for net in model:
+             net(td)
+     return model, model[0]
+
+
+ # ====================================================================
+ # SAC Loss
+ # ---------
+
+
+ def make_loss_module(cfg, model):
+     """Make loss module and target network updater."""
+     # Create SAC loss
+     loss_module = SACLoss(
+         actor_network=model[0],
+         qvalue_network=model[1],
+         num_qvalue_nets=2,
+         loss_function=cfg.optim.loss_function,
+         delay_actor=False,
+         delay_qvalue=True,
+         alpha_init=cfg.optim.alpha_init,
+     )
+     loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+     # Define Target Network Updater
+     target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
+     return loss_module, target_net_updater
+
+
+ def split_critic_params(critic_params):
+     critic1_params = []
+     critic2_params = []
+
+     for param in critic_params:
+         data1, data2 = param.data.chunk(2, dim=0)
+         critic1_params.append(nn.Parameter(data1))
+         critic2_params.append(nn.Parameter(data2))
+     return critic1_params, critic2_params
+
+
+ def make_sac_optimizer(cfg, loss_module):
+     critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+     actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+     optimizer_actor = optim.Adam(
+         actor_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.adam_eps,
+     )
+     optimizer_critic = optim.Adam(
+         critic_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.adam_eps,
+     )
+     optimizer_alpha = optim.Adam(
+         [loss_module.log_alpha],
+         lr=3.0e-4,
+     )
+     return optimizer_actor, optimizer_critic, optimizer_alpha
+
+
+ # ====================================================================
+ # General utils
+ # ---------
+
+
+ def log_metrics(logger, metrics, step):
+     for metric_name, metric_value in metrics.items():
+         logger.log_scalar(metric_name, metric_value, step)
+
+
+ def get_activation(cfg):
+     if cfg.network.activation == "relu":
+         return nn.ReLU
+     elif cfg.network.activation == "tanh":
+         return nn.Tanh
+     elif cfg.network.activation == "leaky_relu":
+         return nn.LeakyReLU
+     else:
+         raise NotImplementedError
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
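
Taken together, these helpers cover the whole SAC pipeline: environment construction, collection, storage, model and loss building, and optimization. A minimal sketch of how they compose, assuming a Hydra/OmegaConf `cfg` object exposing the fields the helpers read (`cfg.env.seed`, `cfg.optim.batch_size`, and so on); this is illustrative glue, not part of the package:

import torch
from utils import (  # the helpers defined in the file above
    make_collector,
    make_environment,
    make_loss_module,
    make_replay_buffer,
    make_sac_agent,
    make_sac_optimizer,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_env, eval_env = make_environment(cfg)  # cfg assumed to come from Hydra
model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
loss_module, target_net_updater = make_loss_module(cfg, model)
collector = make_collector(cfg, train_env, exploration_policy, compile_mode=None)
replay_buffer = make_replay_buffer(batch_size=cfg.optim.batch_size)
optimizers = make_sac_optimizer(cfg, loss_module)  # (actor, critic, alpha)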
sota-implementations/sac_trainer/train.py
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import hydra
+ from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403
+
+
+ @hydra.main(config_path="config", config_name="config", version_base="1.1")
+ def main(cfg):
+     trainer = hydra.utils.instantiate(cfg.trainer)
+     trainer.train()
+
+
+ if __name__ == "__main__":
+     main()
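
This script delegates all object construction to Hydra: `hydra.utils.instantiate` walks the config and builds whatever class each `_target_` field names, so the trainer and its components come from the config file rather than code. A toy illustration of the mechanism (the `collections.Counter` target is hypothetical, standing in for the torchrl trainer class the shipped config points at):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# _target_ names an importable callable; the remaining keys become its kwargs.
cfg = OmegaConf.create({"trainer": {"_target_": "collections.Counter", "red": 2}})
obj = instantiate(cfg.trainer)  # equivalent to collections.Counter(red=2)
print(obj)  # Counter({'red': 2})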
sota-implementations/td3/td3.py
@@ -0,0 +1,254 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """TD3 Example.
+
+ This is a simple self-contained example of a TD3 training script.
+
+ It supports state environments like MuJoCo.
+
+ The helper functions are coded in the utils.py associated with this script.
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import torch.cuda
+ import tqdm
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import compile_with_warmup, get_available_device, timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_collector,
+     make_environment,
+     make_loss_module,
+     make_optimizer,
+     make_replay_buffer,
+     make_td3_agent,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(version_base="1.1", config_path="", config_name="config")
+ def main(cfg: DictConfig):  # noqa: F821
+     device = (
+         torch.device(cfg.network.device)
+         if cfg.network.device
+         else get_available_device()
+     )
+
+     # Create logger
+     exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="td3_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+
+     # Create environments
+     train_env, eval_env = make_environment(cfg, logger=logger, device=device)
+
+     # Create agent
+     model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)
+
+     # Create TD3 loss
+     loss_module, target_net_updater = make_loss_module(cfg, model)
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
+     # Create off-policy collector
+     collector = make_collector(
+         cfg,
+         train_env,
+         exploration_policy,
+         compile_mode=compile_mode,
+         device=device,
+     )
+
+     # Create replay buffer
+     replay_buffer = make_replay_buffer(
+         batch_size=cfg.optim.batch_size,
+         prb=cfg.replay_buffer.prb,
+         buffer_size=cfg.replay_buffer.size,
+         scratch_dir=cfg.replay_buffer.scratch_dir,
+         device=device,
+         compile=bool(compile_mode),
+     )
+
+     # Create optimizers
+     optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
+
+     prb = cfg.replay_buffer.prb
+
+     def update(sampled_tensordict, update_actor, prb=prb):
+
+         # Compute loss
+         q_loss, *_ = loss_module.value_loss(sampled_tensordict)
+
+         # Update critic
+         q_loss.backward()
+         optimizer_critic.step()
+         optimizer_critic.zero_grad(set_to_none=True)
+
+         # Update actor
+         if update_actor:
+             actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
+
+             actor_loss.backward()
+             optimizer_actor.step()
+             optimizer_actor.zero_grad(set_to_none=True)
+
+             # Update target params
+             target_net_updater.step()
+         else:
+             actor_loss = q_loss.new_zeros(())
+
+         return q_loss.detach(), actor_loss.detach()
+
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+
+     # Main loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     init_random_frames = cfg.collector.init_random_frames
+     num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+     delayed_updates = cfg.optim.policy_update_delay
+     eval_rollout_steps = cfg.env.max_episode_steps
+     eval_iter = cfg.logger.eval_iter
+     frames_per_batch = cfg.collector.frames_per_batch
+     update_counter = 0
+
+     collector_iter = iter(collector)
+     total_iter = len(collector)
+
+     for _ in range(total_iter):
+         timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+         with timeit("collect"):
+             tensordict = next(collector_iter)
+
+         # Update weights of the inference policy
+         collector.update_policy_weights_()
+
+         current_frames = tensordict.numel()
+         pbar.update(current_frames)
+
+         with timeit("rb - extend"):
+             # Add to replay buffer
+             tensordict = tensordict.reshape(-1)
+             replay_buffer.extend(tensordict)
+
+         collected_frames += current_frames
+
+         with timeit("train"):
+             # Optimization steps
+             if collected_frames >= init_random_frames:
+                 (
+                     actor_losses,
+                     q_losses,
+                 ) = ([], [])
+                 for _ in range(num_updates):
+                     # Update actor every delayed_updates
+                     update_counter += 1
+                     update_actor = update_counter % delayed_updates == 0
+
+                     with timeit("rb - sample"):
+                         sampled_tensordict = replay_buffer.sample()
+                     with timeit("update"):
+                         torch.compiler.cudagraph_mark_step_begin()
+                         q_loss, actor_loss = update(sampled_tensordict, update_actor)
+
+                     # Update priority
+                     if prb:
+                         with timeit("rb - priority"):
+                             replay_buffer.update_priority(sampled_tensordict)
+
+                     q_losses.append(q_loss.clone())
+                     if update_actor:
+                         actor_losses.append(actor_loss.clone())
+
+         episode_end = (
+             tensordict["next", "done"]
+             if tensordict["next", "done"].any()
+             else tensordict["next", "truncated"]
+         )
+         episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+         # Logging
+         metrics_to_log = {}
+         if len(episode_rewards) > 0:
+             episode_length = tensordict["next", "step_count"][episode_end]
+             metrics_to_log["train/reward"] = episode_rewards.mean()
+             metrics_to_log["train/episode_length"] = episode_length.sum() / len(
+                 episode_length
+             )
+
+         if collected_frames >= init_random_frames:
+             metrics_to_log["train/q_loss"] = torch.stack(q_losses).mean()
+             if update_actor:
+                 metrics_to_log["train/a_loss"] = torch.stack(actor_losses).mean()
+
+         # Evaluation
+         if abs(collected_frames % eval_iter) < frames_per_batch:
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
+                 eval_rollout = eval_env.rollout(
+                     eval_rollout_steps,
+                     exploration_policy,
+                     auto_cast_to_device=True,
+                     break_when_any_done=True,
+                 )
+                 eval_env.apply(dump_video)
+                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                 metrics_to_log["eval/reward"] = eval_reward
+         if logger is not None:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, collected_frames)
+
+     collector.shutdown()
+     if not eval_env.is_closed:
+         eval_env.close()
+     if not train_env.is_closed:
+         train_env.close()
+
+
+ if __name__ == "__main__":
+     main()
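
One detail worth calling out in the loop above is the TD3 delayed policy update: the critic trains on every gradient step, while the actor and target networks are only updated on every `policy_update_delay`-th step. A standalone sketch of that counter logic, with hypothetical values standing in for the `cfg.collector` / `cfg.optim` fields:

# Hypothetical values; the real ones come from the Hydra config.
frames_per_batch, utd_ratio, policy_update_delay = 1000, 1.0, 2

num_updates = int(frames_per_batch * utd_ratio)  # gradient steps per collected batch
update_counter = 0
for _ in range(num_updates):
    update_counter += 1
    # Critic updates every step; actor and targets every policy_update_delay-th step.
    update_actor = update_counter % policy_update_delay == 0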