torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,164 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """IQL Example.
6
+
7
+ This is a self-contained example of an offline IQL training script.
8
+
9
+ The helper functions are coded in the utils.py associated with this script.
10
+
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import warnings
15
+
16
+ import hydra
17
+ import numpy as np
18
+ import torch
19
+ import tqdm
20
+ from tensordict.nn import CudaGraphModule
21
+ from torchrl._utils import get_available_device, timeit
22
+ from torchrl.envs import set_gym_backend
23
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
24
+ from torchrl.objectives import group_optimizers
25
+ from torchrl.record.loggers import generate_exp_name, get_logger
26
+ from utils import (
27
+ dump_video,
28
+ log_metrics,
29
+ make_environment,
30
+ make_iql_model,
31
+ make_iql_optimizer,
32
+ make_loss,
33
+ make_offline_replay_buffer,
34
+ )
35
+
36
# Configure float32 matmul precision once for the whole script.
torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config")
def main(cfg: DictConfig):  # noqa: F821
    """Run offline IQL training from a Hydra configuration.

    Builds the logger, environments, offline replay buffer, model, loss and
    optimizers from ``cfg``, then performs ``cfg.optim.gradient_steps``
    gradient updates on batches sampled from the replay buffer. Every
    ``cfg.logger.eval_iter`` steps the policy (``model[0]``) is rolled out in
    the evaluation environment with deterministic exploration and the summed
    reward is added to the logged metrics.

    Args:
        cfg: Hydra configuration loaded from ``offline_config``. The fields
            read here live under ``env``, ``logger``, ``optim``, ``loss``,
            ``replay_buffer`` and ``compile``.
    """
    set_gym_backend(cfg.env.backend).set()

    # Create logger (skipped entirely when no backend is configured)
    exp_name = generate_exp_name("IQL-offline", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="iql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    # An explicitly configured device wins; otherwise fall back to the helper.
    device = (
        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
    )

    # Create env
    train_env, eval_env = make_environment(
        cfg,
        cfg.logger.eval_envs,
        logger=logger,
    )

    pbar = None
    try:
        # Create replay buffer
        replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

        # Create agent
        model = make_iql_model(cfg, train_env, eval_env, device)

        # Create loss
        loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)

        # Create optimizer: the three optimizers are grouped so that a single
        # zero_grad()/step() pair drives all of them.
        optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
            cfg.optim, loss_module
        )
        optimizer = group_optimizers(
            optimizer_actor, optimizer_critic, optimizer_value
        )

        def update(data):
            """Run one optimization step on ``data`` and return detached losses."""
            optimizer.zero_grad(set_to_none=True)
            # compute losses
            loss_info = loss_module(data)
            actor_loss = loss_info["loss_actor"]
            value_loss = loss_info["loss_value"]
            q_loss = loss_info["loss_qvalue"]

            # One backward pass over the summed losses.
            (actor_loss + value_loss + q_loss).backward()
            optimizer.step()

            # update qnet_target params
            target_net_updater.step()
            return loss_info.detach()

        # Resolve the torch.compile mode: an empty/None configured mode
        # defaults to "default" when cudagraphs are requested and to
        # "reduce-overhead" otherwise.
        compile_mode = None
        if cfg.compile.compile:
            compile_mode = cfg.compile.compile_mode
            if compile_mode in ("", None):
                if cfg.compile.cudagraphs:
                    compile_mode = "default"
                else:
                    compile_mode = "reduce-overhead"

        if cfg.compile.compile:
            update = torch.compile(update, mode=compile_mode)
        if cfg.compile.cudagraphs:
            warnings.warn(
                "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
                category=UserWarning,
            )
            update = CudaGraphModule(update, warmup=50)

        pbar = tqdm.tqdm(range(cfg.optim.gradient_steps))

        evaluation_interval = cfg.logger.eval_iter
        eval_steps = cfg.logger.eval_steps

        # Training loop
        for i in pbar:
            timeit.printevery(1000, cfg.optim.gradient_steps, erase=True)

            # sample data
            with timeit("sample"):
                data = replay_buffer.sample()
                data = data.to(device)

            with timeit("update"):
                torch.compiler.cudagraph_mark_step_begin()
                loss_info = update(data)

            # evaluation
            metrics_to_log = loss_info.to_dict()
            if i % evaluation_interval == 0:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad(), timeit("eval"):
                    eval_td = eval_env.rollout(
                        max_steps=eval_steps,
                        policy=model[0],
                        auto_cast_to_device=True,
                    )
                    eval_env.apply(dump_video)
                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
                metrics_to_log["evaluation_reward"] = eval_reward
            if logger is not None:
                metrics_to_log.update(timeit.todict(prefix="time"))
                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
                log_metrics(logger, metrics_to_log, i)
    finally:
        # Release the progress bar and both environments even when setup or
        # training raises, instead of only on the success path.
        if pbar is not None:
            pbar.close()
        if not eval_env.is_closed:
            eval_env.close()
        if not train_env.is_closed:
            train_env.close()


if __name__ == "__main__":
    main()
@@ -0,0 +1,225 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """IQL Example.
6
+
7
+ This is a self-contained example of an online IQL training script.
8
+
9
+ It works across Gym and MuJoCo over a variety of tasks.
10
+
11
+ The helper functions are coded in the utils.py associated with this script.
12
+
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import warnings
17
+
18
+ import hydra
19
+ import numpy as np
20
+ import torch
21
+ import tqdm
22
+ from tensordict.nn import CudaGraphModule
23
+ from torchrl._utils import get_available_device, timeit
24
+ from torchrl.envs import set_gym_backend
25
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
26
+ from torchrl.objectives import group_optimizers
27
+ from torchrl.record.loggers import generate_exp_name, get_logger
28
+ from utils import (
29
+ dump_video,
30
+ log_metrics,
31
+ make_collector,
32
+ make_environment,
33
+ make_iql_model,
34
+ make_iql_optimizer,
35
+ make_loss,
36
+ make_replay_buffer,
37
+ )
38
+
39
# Configure float32 matmul precision once for the whole script.
torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="online_config")
def main(cfg: DictConfig):  # noqa: F821
    """Run online IQL training from a Hydra configuration.

    Builds the logger, environments, replay buffer, model, collector, loss
    and optimizers from ``cfg``. Each collector batch is flattened and added
    to the replay buffer; once ``cfg.collector.init_random_frames`` frames
    have been collected, ``int(frames_per_batch * utd_ratio)`` gradient
    updates are run per batch. The policy (``model[0]``) is periodically
    rolled out in the evaluation environment with deterministic exploration.

    Args:
        cfg: Hydra configuration loaded from ``online_config``. The fields
            read here live under ``env``, ``logger``, ``optim``, ``loss``,
            ``replay_buffer``, ``collector`` and ``compile``.
    """
    set_gym_backend(cfg.env.backend).set()

    # Create logger (skipped entirely when no backend is configured)
    exp_name = generate_exp_name("IQL-online", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="iql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    # An explicitly configured device wins; otherwise fall back to the helper.
    device = (
        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
    )

    # Create environments
    train_env, eval_env = make_environment(
        cfg,
        cfg.env.train_num_envs,
        cfg.env.eval_num_envs,
        logger=logger,
    )

    collector = None
    pbar = None
    try:
        # Create replay buffer
        replay_buffer = make_replay_buffer(
            batch_size=cfg.optim.batch_size,
            prb=cfg.replay_buffer.prb,
            buffer_size=cfg.replay_buffer.size,
            device="cpu",
        )

        # Create model
        model = make_iql_model(cfg, train_env, eval_env, device)

        # Resolve the torch.compile mode: an empty/None configured mode
        # defaults to "default" when cudagraphs are requested and to
        # "reduce-overhead" otherwise.
        compile_mode = None
        if cfg.compile.compile:
            compile_mode = cfg.compile.compile_mode
            if compile_mode in ("", None):
                if cfg.compile.cudagraphs:
                    compile_mode = "default"
                else:
                    compile_mode = "reduce-overhead"

        # Create collector
        collector = make_collector(
            cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
        )

        # Create loss
        loss_module, target_net_updater = make_loss(cfg.loss, model, device=device)

        # Create optimizer: the three optimizers are grouped so that a single
        # zero_grad()/step() pair drives all of them.
        optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
            cfg.optim, loss_module
        )
        optimizer = group_optimizers(
            optimizer_actor, optimizer_critic, optimizer_value
        )
        del optimizer_actor, optimizer_critic, optimizer_value

        def update(sampled_tensordict):
            """Run one optimization step on the batch and return detached losses."""
            optimizer.zero_grad(set_to_none=True)
            # compute losses
            loss_info = loss_module(sampled_tensordict)
            actor_loss = loss_info["loss_actor"]
            value_loss = loss_info["loss_value"]
            q_loss = loss_info["loss_qvalue"]

            # One backward pass over the summed losses.
            (actor_loss + value_loss + q_loss).backward()
            optimizer.step()

            # update qnet_target params
            target_net_updater.step()
            return loss_info.detach()

        if cfg.compile.compile:
            update = torch.compile(update, mode=compile_mode)
        if cfg.compile.cudagraphs:
            warnings.warn(
                "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
                category=UserWarning,
            )
            update = CudaGraphModule(update, warmup=50)

        # Main loop
        collected_frames = 0

        init_random_frames = cfg.collector.init_random_frames
        num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
        prb = cfg.replay_buffer.prb
        eval_iter = cfg.logger.eval_iter
        frames_per_batch = cfg.collector.frames_per_batch
        eval_rollout_steps = cfg.collector.max_frames_per_traj
        collector_iter = iter(collector)
        pbar = tqdm.tqdm(range(collector.total_frames))
        total_iter = len(collector)
        # Stays None until at least one update has run, so the logging code
        # below never reads an unbound name (e.g. when num_updates is 0).
        loss_info = None
        for _ in range(total_iter):
            timeit.printevery(1000, total_iter, erase=True)

            with timeit("collection"):
                tensordict = next(collector_iter)
            current_frames = tensordict.numel()
            pbar.update(current_frames)
            # update weights of the inference policy
            collector.update_policy_weights_()

            with timeit("rb - extend"):
                # add to replay buffer
                tensordict = tensordict.reshape(-1)
                replay_buffer.extend(tensordict.cpu())
            collected_frames += current_frames

            # optimization steps
            with timeit("training"):
                if collected_frames >= init_random_frames:
                    for _ in range(num_updates):
                        with timeit("rb - sampling"):
                            # sample from replay buffer
                            sampled_tensordict = replay_buffer.sample().to(device)
                        with timeit("update"):
                            torch.compiler.cudagraph_mark_step_begin()
                            loss_info = update(sampled_tensordict)
                        # update priority
                        if prb:
                            replay_buffer.update_priority(sampled_tensordict)
            # Episode rewards of the transitions flagged as done in this batch.
            episode_rewards = tensordict["next", "episode_reward"][
                tensordict["next", "done"]
            ]

            # Logging
            metrics_to_log = {}
            # Evaluation: triggered when the frame counter falls within one
            # batch of a multiple of eval_iter.
            if abs(collected_frames % eval_iter) < frames_per_batch:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad(), timeit("evaluating"):
                    eval_rollout = eval_env.rollout(
                        eval_rollout_steps,
                        model[0],
                        auto_cast_to_device=True,
                        break_when_any_done=True,
                    )
                    eval_env.apply(dump_video)
                    eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                    metrics_to_log["eval/reward"] = eval_reward
            if len(episode_rewards) > 0:
                episode_length = tensordict["next", "step_count"][
                    tensordict["next", "done"]
                ]
                metrics_to_log["train/reward"] = episode_rewards.mean().item()
                metrics_to_log["train/episode_length"] = (
                    episode_length.sum().item() / len(episode_length)
                )
            if loss_info is not None and collected_frames >= init_random_frames:
                metrics_to_log["train/q_loss"] = loss_info["loss_qvalue"]
                metrics_to_log["train/actor_loss"] = loss_info["loss_actor"]
                metrics_to_log["train/value_loss"] = loss_info["loss_value"]
                metrics_to_log["train/entropy"] = loss_info.get("entropy")

            if logger is not None:
                metrics_to_log.update(timeit.todict(prefix="time"))
                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
                log_metrics(logger, metrics_to_log, collected_frames)
    finally:
        # Release the progress bar, collector and environments even when
        # setup or training raises, instead of only on the success path.
        if pbar is not None:
            pbar.close()
        if collector is not None:
            collector.shutdown()

        if not eval_env.is_closed:
            eval_env.close()
        if not train_env.is_closed:
            train_env.close()


if __name__ == "__main__":
    main()