torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/cql/cql_offline.py
@@ -0,0 +1,198 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """CQL Example.
+
+ This is a self-contained example of an offline CQL training script.
+
+ The helper functions are coded in the utils.py associated with this script.
+
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import tqdm
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_continuous_cql_optimizer,
+     make_continuous_loss,
+     make_cql_model,
+     make_environment,
+     make_offline_replay_buffer,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(config_path="", config_name="offline_config", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+     # Create logger
+     exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="cql_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     # Create replay buffer
+     replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
+     # Create env
+     train_env, eval_env = make_environment(
+         cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
+     )
+
+     # Create agent
+     model = make_cql_model(cfg, train_env, eval_env, device)
+     del train_env
+     if hasattr(eval_env, "start"):
+         # To set the number of threads to the definitive value
+         eval_env.start()
+
+     # Create loss
+     loss_module, target_net_updater = make_continuous_loss(
+         cfg.loss, model, device=device
+     )
+
+     # Create Optimizer
+     (
+         policy_optim,
+         critic_optim,
+         alpha_optim,
+         alpha_prime_optim,
+     ) = make_continuous_cql_optimizer(cfg, loss_module)
+
+     # Group optimizers
+     optimizer = group_optimizers(
+         policy_optim, critic_optim, alpha_optim, alpha_prime_optim
+     )
+
+     def update(data, policy_eval_start, iteration):
+         loss_vals = loss_module(data.to(device))
+
+         # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
+         actor_loss = torch.where(
+             iteration >= policy_eval_start,
+             loss_vals["loss_actor"],
+             loss_vals["loss_actor_bc"],
+         )
+         q_loss = loss_vals["loss_qvalue"]
+         cql_loss = loss_vals["loss_cql"]
+
+         q_loss = q_loss + cql_loss
+         loss_vals["q_loss"] = q_loss
+
+         # update model
+         alpha_loss = loss_vals["loss_alpha"]
+         alpha_prime_loss = loss_vals["loss_alpha_prime"]
+         if alpha_prime_loss is None:
+             alpha_prime_loss = 0
+
+         loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad(set_to_none=True)
+
+         # update qnet_target params
+         target_net_updater.step()
+
+         return loss.detach(), loss_vals.detach()
+
+     compile_mode = None
+     if cfg.compile.compile:
+         if cfg.compile.compile_mode not in (None, ""):
+             compile_mode = cfg.compile.compile_mode
+         elif cfg.compile.cudagraphs:
+             compile_mode = "default"
+         else:
+             compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+
+     gradient_steps = cfg.optim.gradient_steps
+     policy_eval_start = cfg.optim.policy_eval_start
+     evaluation_interval = cfg.logger.eval_iter
+     eval_steps = cfg.logger.eval_steps
+
+     # Training loop
+     policy_eval_start = torch.tensor(policy_eval_start, device=device)
+     for i in range(gradient_steps):
+         timeit.printevery(1000, gradient_steps, erase=True)
+         pbar.update(1)
+         # sample data
+         with timeit("sample"):
+             data = replay_buffer.sample()
+
+         with timeit("update"):
+             # compute loss
+             torch.compiler.cudagraph_mark_step_begin()
+             i_device = torch.tensor(i, device=device)
+             loss, loss_vals = update(
+                 data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
+             )
+
+         # log metrics
+         metrics_to_log = {
+             "loss": loss.cpu(),
+             **loss_vals.cpu(),
+         }
+
+         # evaluation
+         with timeit("log/eval"):
+             if i % evaluation_interval == 0:
+                 with set_exploration_type(
+                     ExplorationType.DETERMINISTIC
+                 ), torch.no_grad():
+                     eval_td = eval_env.rollout(
+                         max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
+                     )
+                     eval_env.apply(dump_video)
+                     eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+                     metrics_to_log["evaluation_reward"] = eval_reward
+
+         with timeit("log"):
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+
+             log_metrics(logger, metrics_to_log, i)
+
+     pbar.close()
+     if not eval_env.is_closed:
+         eval_env.close()
+
+
+ if __name__ == "__main__":
+     main()
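For reference, the update closure in cql_offline.py selects between the actor loss and a behaviour-cloning loss with torch.where, and folds the CQL regularizer into the Q loss before a single backward pass. The standalone sketch below reproduces that arithmetic with made-up numbers; the dictionary values and the policy_eval_start threshold are hypothetical and only illustrate the selection, not values shipped with the package.

import torch

# Stand-in for the TensorDict returned by loss_module(data); numbers are made up.
loss_vals = {
    "loss_actor": torch.tensor(0.70),
    "loss_actor_bc": torch.tensor(1.20),
    "loss_qvalue": torch.tensor(0.40),
    "loss_cql": torch.tensor(0.90),
    "loss_alpha": torch.tensor(0.10),
    "loss_alpha_prime": torch.tensor(0.05),
}
iteration = torch.tensor(10)             # current gradient step
policy_eval_start = torch.tensor(40000)  # hypothetical cfg.optim.policy_eval_start

# Behaviour cloning is used for the actor until policy_eval_start steps have elapsed.
actor_loss = torch.where(
    iteration >= policy_eval_start,
    loss_vals["loss_actor"],
    loss_vals["loss_actor_bc"],
)
# The CQL regularizer is added to the Q loss, then all terms are summed.
q_loss = loss_vals["loss_qvalue"] + loss_vals["loss_cql"]
total = actor_loss + q_loss + loss_vals["loss_alpha"] + loss_vals["loss_alpha_prime"]
print(total)  # tensor(2.6500) = 1.20 + (0.40 + 0.90) + 0.10 + 0.05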
sota-implementations/cql/cql_online.py
@@ -0,0 +1,249 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """CQL Example.
+
+ This is a self-contained example of an online CQL training script.
+
+ It works across Gym and MuJoCo over a variety of tasks.
+
+ The helper functions are coded in the utils.py associated with this script.
+
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import tqdm
+ from tensordict import TensorDict
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_collector,
+     make_continuous_cql_optimizer,
+     make_continuous_loss,
+     make_cql_model,
+     make_environment,
+     make_replay_buffer,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(version_base="1.1", config_path="", config_name="online_config")
+ def main(cfg: DictConfig):  # noqa: F821
+     # Create logger
+     exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="cql_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     # Create env
+     train_env, eval_env = make_environment(
+         cfg,
+         cfg.env.train_num_envs,
+         cfg.env.eval_num_envs,
+         logger=logger,
+     )
+
+     # Create replay buffer
+     replay_buffer = make_replay_buffer(
+         batch_size=cfg.optim.batch_size,
+         prb=cfg.replay_buffer.prb,
+         buffer_size=cfg.replay_buffer.size,
+         device="cpu",
+     )
+
+     # create agent
+     model = make_cql_model(cfg, train_env, eval_env, device)
+
+     compile_mode = None
+     if cfg.compile.compile:
+         if cfg.compile.compile_mode not in (None, ""):
+             compile_mode = cfg.compile.compile_mode
+         elif cfg.compile.cudagraphs:
+             compile_mode = "default"
+         else:
+             compile_mode = "reduce-overhead"
+
+     # Create collector
+     collector = make_collector(
+         cfg,
+         train_env,
+         actor_model_explore=model[0],
+         compile=cfg.compile.compile,
+         compile_mode=compile_mode,
+         cudagraph=cfg.compile.cudagraphs,
+     )
+
+     # Create loss
+     loss_module, target_net_updater = make_continuous_loss(
+         cfg.loss, model, device=device
+     )
+
+     # Create optimizer
+     (
+         policy_optim,
+         critic_optim,
+         alpha_optim,
+         alpha_prime_optim,
+     ) = make_continuous_cql_optimizer(cfg, loss_module)
+     optimizer = group_optimizers(
+         policy_optim, critic_optim, alpha_optim, alpha_prime_optim
+     )
+
+     def update(sampled_tensordict):
+
+         loss_td = loss_module(sampled_tensordict)
+
+         actor_loss = loss_td["loss_actor"]
+         q_loss = loss_td["loss_qvalue"]
+         cql_loss = loss_td["loss_cql"]
+         q_loss = q_loss + cql_loss
+         alpha_loss = loss_td["loss_alpha"]
+         alpha_prime_loss = loss_td["loss_alpha_prime"]
+
+         total_loss = alpha_loss + actor_loss + alpha_prime_loss + q_loss
+         total_loss.backward()
+         optimizer.step()
+         optimizer.zero_grad(set_to_none=True)
+
+         # update qnet_target params
+         target_net_updater.step()
+
+         return loss_td.detach()
+
+     if compile_mode:
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Main loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     init_random_frames = cfg.collector.init_random_frames
+     num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+     prb = cfg.replay_buffer.prb
+     frames_per_batch = cfg.collector.frames_per_batch
+     evaluation_interval = cfg.logger.log_interval
+     eval_rollout_steps = cfg.logger.eval_steps
+
+     c_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+         with timeit("collecting"):
+             tensordict = next(c_iter)
+         pbar.update(tensordict.numel())
+         # update weights of the inference policy
+         collector.update_policy_weights_()
+
+         with timeit("rb - extend"):
+             tensordict = tensordict.view(-1)
+             current_frames = tensordict.numel()
+             # add to replay buffer
+             replay_buffer.extend(tensordict)
+         collected_frames += current_frames
+
+         if collected_frames >= init_random_frames:
+             log_loss_td = TensorDict(batch_size=[num_updates], device=device)
+             for j in range(num_updates):
+                 pbar.set_description(f"optim iter {j}")
+                 with timeit("rb - sample"):
+                     # sample from replay buffer
+                     sampled_tensordict = replay_buffer.sample().to(device)
+
+                 with timeit("update"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     loss_td = update(sampled_tensordict)
+                 log_loss_td[j] = loss_td.detach()
+                 # update priority
+                 if prb:
+                     with timeit("rb - update priority"):
+                         replay_buffer.update_priority(sampled_tensordict)
+
+         episode_rewards = tensordict["next", "episode_reward"][
+             tensordict["next", "done"]
+         ]
+         # Logging
+         metrics_to_log = {}
+         if len(episode_rewards) > 0:
+             episode_length = tensordict["next", "step_count"][
+                 tensordict["next", "done"]
+             ]
+             metrics_to_log["train/reward"] = episode_rewards.mean().item()
+             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                 episode_length
+             )
+         if collected_frames >= init_random_frames:
+             metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean()
+             metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean()
+             metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean()
+             metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
+                 "loss_alpha_prime"
+             ).mean()
+             metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
+
+         # Evaluation
+         with timeit("eval"):
+             prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval
+             cur_test_frame = (i * frames_per_batch) // evaluation_interval
+             final = current_frames >= collector.total_frames
+             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
+                 with set_exploration_type(
+                     ExplorationType.DETERMINISTIC
+                 ), torch.no_grad():
+                     eval_rollout = eval_env.rollout(
+                         eval_rollout_steps,
+                         model[0],
+                         auto_cast_to_device=True,
+                         break_when_any_done=True,
+                     )
+                     eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                     eval_env.apply(dump_video)
+                     metrics_to_log["eval/reward"] = eval_reward
+
+         metrics_to_log.update(timeit.todict(prefix="time"))
+         metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+         log_metrics(logger, metrics_to_log, collected_frames)
+
+     collector.shutdown()
+     if not eval_env.is_closed:
+         eval_env.close()
+     if not train_env.is_closed:
+         train_env.close()
+
+
+ if __name__ == "__main__":
+     main()
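A note on the bookkeeping in cql_online.py: the number of gradient updates per collected batch is frames_per_batch * utd_ratio, and evaluation fires whenever the cumulative frame count crosses a multiple of cfg.logger.log_interval, detected by comparing integer quotients between consecutive iterations. A minimal sketch of that logic follows, with hypothetical config values; frames_per_batch, utd_ratio and log_interval below are assumptions, not defaults shipped with the package.

frames_per_batch = 1000  # hypothetical cfg.collector.frames_per_batch
utd_ratio = 1.0          # hypothetical cfg.optim.utd_ratio
log_interval = 5000      # hypothetical cfg.logger.log_interval

num_updates = int(frames_per_batch * utd_ratio)  # 1000 updates per collected batch

# Evaluation triggers when the integer quotient of collected frames by
# log_interval increases from one collector iteration to the next.
for i in range(1, 12):
    prev_test_frame = ((i - 1) * frames_per_batch) // log_interval
    cur_test_frame = (i * frames_per_batch) // log_interval
    if prev_test_frame < cur_test_frame:
        print(f"evaluate at iteration {i}")  # prints for i = 5 and i = 10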
sota-implementations/cql/discrete_cql_offline.py
@@ -0,0 +1,180 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """CQL Example.
+
+ This is a self-contained example of a discrete offline CQL training script.
+
+ The helper functions are coded in the utils.py associated with this script.
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import tqdm
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_discrete_cql_optimizer,
+     make_discrete_loss,
+     make_discretecql_model,
+     make_environment,
+     make_offline_discrete_replay_buffer,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
+ def main(cfg):  # noqa: F821
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     # Create logger
+     exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="discretecql_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+     if cfg.env.seed is not None:
+         warnings.warn(
+             "The seed in the environment config is deprecated. "
+             "Please set the seed in the optim config instead."
+         )
+
+     # Create replay buffer
+     replay_buffer = make_offline_discrete_replay_buffer(cfg.replay_buffer)
+
+     # Create env
+     train_env, eval_env = make_environment(
+         cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
+     )
+
+     # Create agent
+     model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)
+
+     del train_env
+
+     # Create loss
+     loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device)
+
+     # Create optimizers
+     optimizer = make_discrete_cql_optimizer(cfg, loss_module)  # optimizer for CQL loss
+
+     def update(data):
+
+         # Compute loss components
+         loss_vals = loss_module(data)
+
+         q_loss = loss_vals["loss_qvalue"]
+         cql_loss = loss_vals["loss_cql"]
+
+         # Total loss = Q-learning loss + CQL regularization
+         loss = q_loss + cql_loss
+
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad(set_to_none=True)
+
+         # Soft update of target Q-network
+         target_net_updater.step()
+
+         # Detach to avoid keeping computation graph in logging
+         return loss.detach(), loss_vals.detach()
+
+     compile_mode = None
+     if cfg.compile.compile:
+         if cfg.compile.compile_mode not in (None, ""):
+             compile_mode = cfg.compile.compile_mode
+         elif cfg.compile.cudagraphs:
+             compile_mode = "default"
+         else:
+             compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
+
+     gradient_steps = cfg.optim.gradient_steps
+     policy_eval_start = cfg.optim.policy_eval_start
+     evaluation_interval = cfg.logger.eval_iter
+     eval_steps = cfg.logger.eval_steps
+
+     # Training loop
+     policy_eval_start = torch.tensor(policy_eval_start, device=device)
+     for i in range(gradient_steps):
+         timeit.printevery(1000, gradient_steps, erase=True)
+         pbar.update(1)
+         # sample data
+         with timeit("sample"):
+             data = replay_buffer.sample()
+
+         with timeit("update"):
+             torch.compiler.cudagraph_mark_step_begin()
+             loss, loss_vals = update(data.to(device))
+
+         # log metrics
+         metrics_to_log = {
+             "loss": loss.cpu(),
+             **loss_vals.cpu(),
+         }
+
+         # evaluation
+         with timeit("log/eval"):
+             if i % evaluation_interval == 0:
+                 with set_exploration_type(
+                     ExplorationType.DETERMINISTIC
+                 ), torch.no_grad():
+                     eval_td = eval_env.rollout(
+                         max_steps=eval_steps,
+                         policy=explore_policy,
+                         auto_cast_to_device=True,
+                     )
+                     eval_env.apply(dump_video)
+
+                     # eval_td: matrix of shape: [num_episodes, max_steps, ...]
+                     eval_reward = (
+                         eval_td["next", "reward"].sum(1).mean().item()
+                     )  # mean computed over the sum of rewards for each episode
+                     metrics_to_log["evaluation_reward"] = eval_reward
+
+         with timeit("log"):
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, i)
+
+     pbar.close()
+     if not eval_env.is_closed:
+         eval_env.close()
+
+
+ if __name__ == "__main__":
+     main()