torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/decision_transformer/lamb.py
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# Lamb optimizer directly copied from https://github.com/facebookresearch/online-dt
+from __future__ import annotations
+
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class Lamb(Optimizer):
+    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
+    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+        lr (:obj:`float`, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its norm. (default: (0.9, 0.999))
+        eps (:obj:`float`, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (:obj:`float`, optional): weight decay (L2 penalty) (default: 0)
+        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
+            calculating running averages of gradient. (default: True)
+        max_grad_norm (:obj:`float`, optional): value used to clip global grad norm (default: 1.0)
+        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
+    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-6,
+        weight_decay=0.01,
+        grad_averaging=True,
+        max_grad_norm=1.0,
+        trust_clip=False,
+        always_adapt=False,
+    ):
+        defaults = {
+            "lr": lr,
+            "bias_correction": bias_correction,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "grad_averaging": grad_averaging,
+            "max_grad_norm": max_grad_norm,
+            "trust_clip": trust_clip,
+            "always_adapt": always_adapt,
+        }
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        device = self.param_groups[0]["params"][0].device
+        one_tensor = torch.tensor(
+            1.0, device=device
+        )  # because torch.where doesn't handle scalars correctly
+        global_grad_norm = torch.zeros(1, device=device)
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Lamb does not support sparse gradients, consider SparseAdam instead."
+                    )
+                global_grad_norm.add_(grad.pow(2).sum())
+
+        global_grad_norm = torch.sqrt(global_grad_norm)
+        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+        max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device)
+        clip_global_grad_norm = torch.where(
+            global_grad_norm > max_grad_norm,
+            global_grad_norm / max_grad_norm,
+            one_tensor,
+        )
+
+        for group in self.param_groups:
+            bias_correction = 1 if group["bias_correction"] else 0
+            beta1, beta2 = group["betas"]
+            grad_averaging = 1 if group["grad_averaging"] else 0
+            beta3 = 1 - beta1 if grad_averaging else 1.0
+
+            # assume same step across group now to simplify things
+            # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
+            if "step" in group:
+                group["step"] += 1
+            else:
+                group["step"] = 1
+
+            if bias_correction:
+                bias_correction1 = 1 - beta1 ** group["step"]
+                bias_correction2 = 1 - beta2 ** group["step"]
+            else:
+                bias_correction1, bias_correction2 = 1.0, 1.0
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.div_(clip_global_grad_norm)
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
+
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
+                    group["eps"]
+                )
+                update = (exp_avg / bias_correction1).div_(denom)
+
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    update.add_(p, alpha=weight_decay)
+
+                if weight_decay != 0 or group["always_adapt"]:
+                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
+                    # excluded from weight decay, unless always_adapt == True, then always enabled.
+                    w_norm = p.norm(2.0)
+                    g_norm = update.norm(2.0)
+                    # FIXME nested where required since logical and/or not working in PT XLA
+                    trust_ratio = torch.where(
+                        w_norm > 0,
+                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
+                        one_tensor,
+                    )
+                    if group["trust_clip"]:
+                        # LAMBC trust clipping, upper bound fixed at one
+                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                    update.mul_(trust_ratio)
+
+                p.add_(update, alpha=-group["lr"])
+
+        return loss
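
For context, a minimal sketch of how an optimizer like the one above can be driven. The toy model, hyperparameter values, and the local import path are illustrative assumptions only (lamb.py is shipped as a standalone script next to the decision_transformer training code, not as an installable module), not something taken from the package itself.

import torch

from lamb import Lamb  # assumes lamb.py from the hunk above sits next to this snippet

model = torch.nn.Linear(8, 2)
optimizer = Lamb(model.parameters(), lr=1e-4, weight_decay=5e-4, max_grad_norm=0.25)

for _ in range(3):
    loss = model(torch.randn(16, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # global grad-norm clipping and trust-ratio scaling happen inside step()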
sota-implementations/decision_transformer/online_dt.py
@@ -0,0 +1,178 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Online Decision Transformer Example.
+This is a self-contained example of an Online Decision Transformer training script.
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import tqdm
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
+from torchrl.envs.libs.gym import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
+from torchrl.record import VideoRecorder
+from utils import (
+    dump_video,
+    log_metrics,
+    make_env,
+    make_logger,
+    make_odt_loss,
+    make_odt_model,
+    make_odt_optimizer,
+    make_offline_replay_buffer,
+)
+
+
+@hydra.main(config_path="", config_name="odt_config", version_base="1.1")
+def main(cfg: DictConfig):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
+    model_device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
+    logger = make_logger(cfg)
+
+    # Create offline replay buffer
+    offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
+        cfg.replay_buffer, cfg.env.reward_scaling
+    )
+
+    # Create test environment
+    test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video)
+    if cfg.logger.video:
+        test_env = test_env.append_transform(
+            VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+
+    # Create policy model
+    policy = make_odt_model(cfg, device=model_device)
+
+    # Create loss
+    loss_module = make_odt_loss(cfg.loss, policy)
+
+    # Create optimizer
+    transformer_optim, temperature_optim, scheduler = make_odt_optimizer(
+        cfg.optim, loss_module
+    )
+
+    # Create inference policy
+    inference_policy = DecisionTransformerInferenceWrapper(
+        policy=policy,
+        inference_context=cfg.env.inference_context,
+        device=model_device,
+    )
+    inference_policy.set_tensor_keys(
+        observation="observation_cat",
+        action="action_cat",
+        return_to_go="return_to_go_cat",
+    )
+
+    def update(data):
+        transformer_optim.zero_grad(set_to_none=True)
+        temperature_optim.zero_grad(set_to_none=True)
+        # Compute loss
+        loss_vals = loss_module(data.to(model_device))
+        transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
+        temperature_loss = loss_vals["loss_alpha"]
+
+        (temperature_loss + transformer_loss).backward()
+        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
+
+        transformer_optim.step()
+        temperature_optim.step()
+
+        return loss_vals.detach()
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            compile_mode = "default"
+        update = torch.compile(update, mode=compile_mode, dynamic=False)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        if cfg.optim.optimizer == "lamb":
+            raise ValueError(
+                "cudagraphs isn't compatible with the Lamb optimizer. Use optim.optimizer=Adam instead."
+            )
+        update = CudaGraphModule(update, warmup=50)
+
+    pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
+
+    pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
+    clip_grad = cfg.optim.clip_grad
+    eval_steps = cfg.logger.eval_steps
+    pretrain_log_interval = cfg.logger.pretrain_log_interval
+    reward_scaling = cfg.env.reward_scaling
+
+    torchrl_logger.info(" ***Pretraining*** ")
+    # Pretraining
+    for i in range(pretrain_gradient_steps):
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
+        pbar.update(1)
+        with timeit("sample"):
+            # Sample data
+            data = offline_buffer.sample()
+
+        with timeit("update"):
+            torch.compiler.cudagraph_mark_step_begin()
+            loss_vals = update(data.to(model_device))
+
+        scheduler.step()
+
+        # Log metrics
+        metrics_to_log = {
+            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
+            "train/loss_entropy": loss_vals["loss_entropy"],
+            "train/loss_alpha": loss_vals["loss_alpha"],
+            "train/alpha": loss_vals["alpha"],
+            "train/entropy": loss_vals["entropy"],
+        }
+
+        # Evaluation
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
+            inference_policy.eval()
+            if i % pretrain_log_interval == 0:
+                eval_td = test_env.rollout(
+                    max_steps=eval_steps,
+                    policy=inference_policy,
+                    auto_cast_to_device=True,
+                    break_when_any_done=False,
+                )
+                test_env.apply(dump_video)
+            inference_policy.train()
+            metrics_to_log["eval/reward"] = (
+                eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+            )
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
+
+    pbar.close()
+    if not test_env.is_closed:
+        test_env.close()
+
+
+if __name__ == "__main__":
+    main()
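
As a usage note: online_dt.py is a Hydra entry point. The @hydra.main decorator resolves an odt_config file located next to the script (the config itself is not part of this excerpt), and the cfg.* fields read above (env.backend, env.seed, optim.device, optim.pretrain_gradient_steps, logger.video, compile.compile, compile.cudagraphs, ...) can be overridden from the command line with dotted keys, for example `python online_dt.py optim.device=cpu logger.video=False` (illustrative values only; consult the shipped config for the actual defaults).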