torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
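
A wheel is a standard ZIP archive, so the file list above can be reproduced locally without installing the package. A minimal sketch, standard library only (the filename is the wheel from the header):

    import zipfile

    # Print every file packaged in the torchrl 0.11.0 wheel.
    with zipfile.ZipFile("torchrl-0.11.0-cp314-cp314-win_amd64.whl") as whl:
        for info in whl.infolist():
            print(info.filename)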
sota-implementations/crossq/crossq.py
@@ -0,0 +1,271 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """CrossQ Example.
+
+ This is a simple self-contained example of a CrossQ training script.
+
+ It supports state environments like MuJoCo.
+
+ The helper functions are coded in the utils.py associated with this script.
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import torch.cuda
+ import tqdm
+ from tensordict import TensorDict
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     log_metrics,
+     make_collector,
+     make_crossQ_agent,
+     make_crossQ_optimizer,
+     make_environment,
+     make_loss_module,
+     make_replay_buffer,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(version_base="1.1", config_path=".", config_name="config")
+ def main(cfg: DictConfig):  # noqa: F821
+     device = (
+         torch.device(cfg.network.device)
+         if cfg.network.device
+         else get_available_device()
+     )
+
+     # Create logger
+     exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="crossq_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+
+     # Create environments
+     train_env, eval_env = make_environment(cfg)
+
+     # Create agent
+     model, exploration_policy = make_crossQ_agent(cfg, train_env, device)
+
+     # Create CrossQ loss
+     loss_module = make_loss_module(cfg, model, device=device)
+
+     compile_mode = None
+     if cfg.compile.compile:
+         if cfg.compile.compile_mode not in (None, ""):
+             compile_mode = cfg.compile.compile_mode
+         elif cfg.compile.cudagraphs:
+             compile_mode = "default"
+         else:
+             compile_mode = "reduce-overhead"
+
+     # Create off-policy collector
+     collector = make_collector(
+         cfg,
+         train_env,
+         exploration_policy.eval(),
+         device=device,
+         compile=cfg.compile.compile,
+         compile_mode=compile_mode,
+         cudagraph=cfg.compile.cudagraphs,
+     )
+
+     # Create replay buffer
+     replay_buffer = make_replay_buffer(
+         batch_size=cfg.optim.batch_size,
+         prb=cfg.replay_buffer.prb,
+         buffer_size=cfg.replay_buffer.size,
+         scratch_dir=cfg.replay_buffer.scratch_dir,
+         device="cpu",
+     )
+
+     # Create optimizers
+     (
+         optimizer_actor,
+         optimizer_critic,
+         optimizer_alpha,
+     ) = make_crossQ_optimizer(cfg, loss_module)
+     optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+     del optimizer_actor, optimizer_critic, optimizer_alpha
+
+     def update_qloss(sampled_tensordict):
+         optimizer.zero_grad(set_to_none=True)
+         td_loss = {}
+         q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
+         sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"])
+         q_loss = q_loss.mean()
+
+         # Update critic
+         q_loss.backward()
+         optimizer.step()
+         td_loss["loss_qvalue"] = q_loss
+         td_loss["loss_actor"] = float("nan")
+         td_loss["loss_alpha"] = float("nan")
+         return TensorDict(td_loss, device=device).detach()
+
+     def update_all(sampled_tensordict: TensorDict):
+         optimizer.zero_grad(set_to_none=True)
+
+         td_loss = {}
+         q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
+         sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"])
+         q_loss = q_loss.mean()
+
+         actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict)
+         actor_loss = actor_loss.mean()
+         alpha_loss = loss_module.alpha_loss(
+             log_prob=metadata_actor["log_prob"].detach()
+         ).mean()
+
+         # Updates
+         (q_loss + actor_loss + alpha_loss).backward()
+         optimizer.step()
+
+         # Log losses
+         td_loss["loss_qvalue"] = q_loss
+         td_loss["loss_actor"] = actor_loss
+         td_loss["loss_alpha"] = alpha_loss
+
+         return TensorDict(td_loss, device=device).detach()
+
+     if compile_mode:
+         update_all = torch.compile(update_all, mode=compile_mode)
+         update_qloss = torch.compile(update_qloss, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update_all = CudaGraphModule(update_all, warmup=50)
+         update_qloss = CudaGraphModule(update_qloss, warmup=50)
+
+     def update(sampled_tensordict: TensorDict, update_actor: bool):
+         if update_actor:
+             return update_all(sampled_tensordict)
+         return update_qloss(sampled_tensordict)
+
+     # Main loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     init_random_frames = cfg.collector.init_random_frames
+     num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+     prb = cfg.replay_buffer.prb
+     eval_iter = cfg.logger.eval_iter
+     frames_per_batch = cfg.collector.frames_per_batch
+     eval_rollout_steps = cfg.env.max_episode_steps
+
+     update_counter = 0
+     delayed_updates = cfg.optim.policy_update_delay
+     c_iter = iter(collector)
+     total_iter = len(collector)
+     for _ in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+         with timeit("collecting"):
+             torch.compiler.cudagraph_mark_step_begin()
+             tensordict = next(c_iter)
+
+         # Update weights of the inference policy
+         collector.update_policy_weights_()
+
+         current_frames = tensordict.numel()
+         pbar.update(current_frames)
+         tensordict = tensordict.reshape(-1)
+
+         with timeit("rb - extend"):
+             # Add to replay buffer
+             replay_buffer.extend(tensordict)
+         collected_frames += current_frames
+
+         # Optimization steps
+         if collected_frames >= init_random_frames:
+             tds = []
+             for _ in range(num_updates):
+                 # Update actor every delayed_updates
+                 update_counter += 1
+                 update_actor = update_counter % delayed_updates == 0
+                 # Sample from replay buffer
+                 with timeit("rb - sample"):
+                     sampled_tensordict = replay_buffer.sample().to(device)
+                 with timeit("update"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     td_loss = update(sampled_tensordict, update_actor=update_actor)
+                 tds.append(td_loss.clone())
+                 # Update priority
+                 if prb:
+                     replay_buffer.update_priority(sampled_tensordict)
+
+             tds = TensorDict.stack(tds).nanmean()
+         episode_end = (
+             tensordict["next", "done"]
+             if tensordict["next", "done"].any()
+             else tensordict["next", "truncated"]
+         )
+         episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+         metrics_to_log = {}
+
+         # Evaluation
+         if abs(collected_frames % eval_iter) < frames_per_batch:
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
+                 eval_rollout = eval_env.rollout(
+                     eval_rollout_steps,
+                     model[0],
+                     auto_cast_to_device=True,
+                     break_when_any_done=True,
+                 )
+                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                 metrics_to_log["eval/reward"] = eval_reward
+
+         # Logging
+         if len(episode_rewards) > 0:
+             episode_length = tensordict["next", "step_count"][episode_end]
+             metrics_to_log["train/reward"] = episode_rewards.mean().item()
+             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                 episode_length
+             )
+         if collected_frames >= init_random_frames:
+             metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+             metrics_to_log["train/actor_loss"] = tds["loss_actor"]
+             metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
+
+         if logger is not None:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, collected_frames)
+
+     collector.shutdown()
+     if not eval_env.is_closed:
+         eval_env.close()
+     if not train_env.is_closed:
+         train_env.close()
+
+
+ if __name__ == "__main__":
+     main()
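
The script above is driven by a Hydra config (`config.yaml`, which sits next to the script but is not included in this diff). For orientation, a hedged sketch of how the main components wire together outside Hydra, using only the helper names the script imports and the config fields it actually reads; every concrete value below is illustrative, not the shipped default:

    import torch
    from omegaconf import OmegaConf

    from utils import make_crossQ_agent, make_environment, make_loss_module

    # Illustrative config; field names mirror the attributes read by crossq.py.
    cfg = OmegaConf.create(
        {
            "env": {
                "name": "HalfCheetah-v4",
                "library": "gymnasium",
                "seed": 0,
                "max_episode_steps": 1000,
            },
            "collector": {"env_per_collector": 1},
            "network": {
                "actor_hidden_sizes": [256, 256],
                "critic_hidden_sizes": [2048, 2048],
                "actor_activation": "relu",
                "critic_activation": "relu",
                "batch_norm_momentum": 0.01,
                "warmup_steps": 100_000,
                "default_policy_scale": 1.0,
                "scale_lb": 0.1,
            },
            "compile": {"compile": False},
            "optim": {"loss_function": "l2", "alpha_init": 1.0, "gamma": 0.99},
        }
    )

    train_env, eval_env = make_environment(cfg)
    model, exploration_policy = make_crossQ_agent(cfg, train_env, torch.device("cpu"))
    loss_module = make_loss_module(cfg, model, device=torch.device("cpu"))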
sota-implementations/crossq/utils.py
@@ -0,0 +1,320 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch
+ from tensordict.nn import InteractionType, TensorDictModule
+ from tensordict.nn.distributions import NormalParamExtractor
+ from torch import nn, optim
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+ from torchrl.envs import (
+     CatTensors,
+     Compose,
+     DMControlEnv,
+     DoubleToFloat,
+     EnvCreator,
+     ParallelEnv,
+     TransformedEnv,
+ )
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+ from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
+ from torchrl.modules.distributions import TanhNormal
+
+ from torchrl.modules.models.batchrenorm import BatchRenorm1d
+ from torchrl.objectives import CrossQLoss
+
+ # ====================================================================
+ # Environment utils
+ # -----------------
+
+
+ def env_maker(cfg, device="cpu"):
+     lib = cfg.env.library
+     if lib in ("gym", "gymnasium"):
+         with set_gym_backend(lib):
+             return GymEnv(
+                 cfg.env.name,
+                 device=device,
+             )
+     elif lib == "dm_control":
+         env = DMControlEnv(cfg.env.name, cfg.env.task)
+         return TransformedEnv(
+             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+         )
+     else:
+         raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+ def apply_env_transforms(env, max_episode_steps=1000):
+     transformed_env = TransformedEnv(
+         env,
+         Compose(
+             InitTracker(),
+             StepCounter(max_episode_steps),
+             DoubleToFloat(),
+             RewardSum(),
+         ),
+     )
+     return transformed_env
+
+
+ def make_environment(cfg):
+     """Make environments for training and evaluation."""
+     parallel_env = ParallelEnv(
+         cfg.collector.env_per_collector,
+         EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+         serial_for_single=True,
+     )
+     parallel_env.set_seed(cfg.env.seed)
+
+     train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)
+
+     eval_env = TransformedEnv(
+         ParallelEnv(
+             cfg.collector.env_per_collector,
+             EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+             serial_for_single=True,
+         ),
+         train_env.transform.clone(),
+     )
+     return train_env, eval_env
+
+
+ # ====================================================================
+ # Collector and replay buffer
+ # ---------------------------
+
+
+ def make_collector(
+     cfg,
+     train_env,
+     actor_model_explore,
+     device,
+     compile=False,
+     compile_mode=None,
+     cudagraph=False,
+ ):
+     """Make collector."""
+     collector = SyncDataCollector(
+         train_env,
+         actor_model_explore,
+         init_random_frames=cfg.collector.init_random_frames,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         device=device,
+         compile_policy={"mode": compile_mode} if compile else False,
+         cudagraph_policy=cudagraph,
+     )
+     collector.set_seed(cfg.env.seed)
+     return collector
+
+
+ def make_replay_buffer(
+     batch_size,
+     prb=False,
+     buffer_size=1000000,
+     scratch_dir=None,
+     device="cpu",
+     prefetch=3,
+ ):
+     if prb:
+         replay_buffer = TensorDictPrioritizedReplayBuffer(
+             alpha=0.7,
+             beta=0.5,
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=LazyMemmapStorage(
+                 buffer_size,
+                 scratch_dir=scratch_dir,
+             ),
+             batch_size=batch_size,
+         )
+     else:
+         replay_buffer = TensorDictReplayBuffer(
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=LazyMemmapStorage(
+                 buffer_size,
+                 scratch_dir=scratch_dir,
+             ),
+             batch_size=batch_size,
+         )
+     replay_buffer.append_transform(lambda x: x.to(device, non_blocking=True))
+     return replay_buffer
+
+
+ # ====================================================================
+ # Model
+ # -----
+
+
+ def make_crossQ_agent(cfg, train_env, device):
+     """Make CrossQ agent."""
+     # Define Actor Network
+     in_keys = ["observation"]
+     action_spec = train_env.action_spec_unbatched
+     actor_net_kwargs = {
+         "num_cells": cfg.network.actor_hidden_sizes,
+         "out_features": 2 * action_spec.shape[-1],
+         "activation_class": get_activation(cfg.network.actor_activation),
+         "norm_class": BatchRenorm1d,
+         "norm_kwargs": {
+             "momentum": cfg.network.batch_norm_momentum,
+             "num_features": cfg.network.actor_hidden_sizes[-1],
+             "warmup_steps": cfg.network.warmup_steps,
+         },
+     }
+
+     actor_net = MLP(**actor_net_kwargs)
+
+     dist_class = TanhNormal
+     dist_kwargs = {
+         "low": torch.as_tensor(action_spec.space.low, device=device),
+         "high": torch.as_tensor(action_spec.space.high, device=device),
+         "tanh_loc": False,
+         "safe_tanh": not cfg.compile.compile,
+     }
+
+     actor_extractor = NormalParamExtractor(
+         scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}",
+         scale_lb=cfg.network.scale_lb,
+     )
+     actor_net = nn.Sequential(actor_net, actor_extractor)
+
+     in_keys_actor = in_keys
+     actor_module = TensorDictModule(
+         actor_net,
+         in_keys=in_keys_actor,
+         out_keys=[
+             "loc",
+             "scale",
+         ],
+     )
+     actor = ProbabilisticActor(
+         spec=action_spec,
+         in_keys=["loc", "scale"],
+         module=actor_module,
+         distribution_class=dist_class,
+         distribution_kwargs=dist_kwargs,
+         default_interaction_type=InteractionType.RANDOM,
+         return_log_prob=False,
+     )
+
+     # Define Critic Network
+     qvalue_net_kwargs = {
+         "num_cells": cfg.network.critic_hidden_sizes,
+         "out_features": 1,
+         "activation_class": get_activation(cfg.network.critic_activation),
+         "norm_class": BatchRenorm1d,
+         "norm_kwargs": {
+             "momentum": cfg.network.batch_norm_momentum,
+             "num_features": cfg.network.critic_hidden_sizes[-1],
+             "warmup_steps": cfg.network.warmup_steps,
+         },
+     }
+
+     qvalue_net = MLP(
+         **qvalue_net_kwargs,
+     )
+
+     qvalue = ValueOperator(
+         in_keys=["action"] + in_keys,
+         module=qvalue_net,
+     )
+
+     model = nn.ModuleList([actor, qvalue]).to(device)
+
+     # init nets
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         td = train_env.fake_tensordict()
+         td = td.to(device)
+         for net in model:
+             net.eval()
+             net(td)
+             net.train()
+     del td
+
+     return model, model[0]
+
+
+ # ====================================================================
+ # CrossQ Loss
+ # -----------
+
+
+ def make_loss_module(cfg, model, device: torch.device | None = None):
+     """Make loss module and target network updater."""
+     # Create CrossQ loss
+     loss_module = CrossQLoss(
+         actor_network=model[0],
+         qvalue_network=model[1],
+         num_qvalue_nets=2,
+         loss_function=cfg.optim.loss_function,
+         alpha_init=cfg.optim.alpha_init,
+     )
+     loss_module.make_value_estimator(gamma=cfg.optim.gamma, device=device)
+
+     return loss_module
+
+
+ def split_critic_params(critic_params):
+     critic1_params = []
+     critic2_params = []
+
+     for param in critic_params:
+         data1, data2 = param.data.chunk(2, dim=0)
+         critic1_params.append(nn.Parameter(data1))
+         critic2_params.append(nn.Parameter(data2))
+     return critic1_params, critic2_params
+
+
+ def make_crossQ_optimizer(cfg, loss_module):
+     critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+     actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+     optimizer_actor = optim.Adam(
+         actor_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.adam_eps,
+         betas=(cfg.optim.beta1, cfg.optim.beta2),
+     )
+     optimizer_critic = optim.Adam(
+         critic_params,
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.adam_eps,
+         betas=(cfg.optim.beta1, cfg.optim.beta2),
+     )
+     optimizer_alpha = optim.Adam(
+         [loss_module.log_alpha],
+         lr=cfg.optim.lr,
+     )
+     return optimizer_actor, optimizer_critic, optimizer_alpha
+
+
+ # ====================================================================
+ # General utils
+ # -------------
+
+
+ def log_metrics(logger, metrics, step):
+     for metric_name, metric_value in metrics.items():
+         logger.log_scalar(metric_name, metric_value, step)
+
+
+ def get_activation(activation: str):
+     if activation == "relu":
+         return nn.ReLU
+     elif activation == "tanh":
+         return nn.Tanh
+     elif activation == "leaky_relu":
+         return nn.LeakyReLU
+     else:
+         raise NotImplementedError(f"Unsupported activation: {activation!r}.")
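
As a quick sanity check of the buffer helper, a minimal sketch (shapes and sizes are arbitrary; `make_replay_buffer` is the function defined above):

    import torch
    from tensordict import TensorDict

    rb = make_replay_buffer(batch_size=256, prb=False, buffer_size=10_000, device="cpu")

    # Fake transitions: 128 steps of a 17-dim observation, 6-dim action task.
    data = TensorDict(
        {"observation": torch.randn(128, 17), "action": torch.randn(128, 6)},
        batch_size=[128],
    )
    rb.extend(data)  # the memmap storage is lazily initialized on first write
    sample = rb.sample()  # TensorDict with batch_size=[256], sampled uniformly
    print(sample["observation"].shape)  # torch.Size([256, 17])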