torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/models/__init__.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from torchrl.modules.tensordict_module.common import DistributionalDQNnet
+
+from .batchrenorm import BatchRenorm1d
+
+from .decision_transformer import DecisionTransformer
+from .exploration import (
+    ConsistentDropout,
+    ConsistentDropoutModule,
+    NoisyLazyLinear,
+    NoisyLinear,
+    reset_noise,
+)
+from .llm import GPT2RewardModel
+from .model_based import (
+    DreamerActor,
+    ObsDecoder,
+    ObsEncoder,
+    RSSMPosterior,
+    RSSMPrior,
+    RSSMRollout,
+)
+from .models import (
+    Conv2dNet,
+    Conv3dNet,
+    ConvNet,
+    DdpgCnnActor,
+    DdpgCnnQNet,
+    DdpgMlpActor,
+    DdpgMlpQNet,
+    DTActor,
+    DuelingCnnDQNet,
+    DuelingMlpDQNet,
+    MLP,
+    OnlineDTActor,
+)
+from .multiagent import (
+    MultiAgentConvNet,
+    MultiAgentMLP,
+    MultiAgentNetBase,
+    QMixer,
+    VDNMixer,
+)
+from .utils import Squeeze2dLayer, SqueezeLayer
+
+__all__ = [
+    "DistributionalDQNnet",
+    "BatchRenorm1d",
+    "DecisionTransformer",
+    "GPT2RewardModel",
+    "ConsistentDropout",
+    "ConsistentDropoutModule",
+    "NoisyLazyLinear",
+    "NoisyLinear",
+    "reset_noise",
+    "DreamerActor",
+    "ObsDecoder",
+    "ObsEncoder",
+    "RSSMPosterior",
+    "RSSMPrior",
+    "RSSMRollout",
+    "Conv2dNet",
+    "Conv3dNet",
+    "ConvNet",
+    "DdpgCnnActor",
+    "DdpgCnnQNet",
+    "DdpgMlpActor",
+    "DdpgMlpQNet",
+    "DTActor",
+    "DuelingCnnDQNet",
+    "DuelingMlpDQNet",
+    "MLP",
+    "OnlineDTActor",
+    "MultiAgentConvNet",
+    "MultiAgentMLP",
+    "MultiAgentNetBase",
+    "QMixer",
+    "VDNMixer",
+    "Squeeze2dLayer",
+    "SqueezeLayer",
+]
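
This module only re-exports the model classes so they can be imported from a single namespace. As a quick illustration (not part of the diff, and assuming torchrl 0.11.0 and PyTorch are installed), the re-exported MLP can be used directly from this namespace; the constructor arguments shown are the commonly documented ones and should be checked against the release:

    import torch
    from torchrl.modules.models import MLP  # re-exported by the __init__ above

    # Illustrative only: a 2-hidden-layer MLP mapping 8 features to 4 outputs.
    net = MLP(in_features=8, out_features=4, num_cells=[64, 64])
    out = net(torch.randn(32, 8))  # expected shape: (32, 4)
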
torchrl/modules/models/batchrenorm.py
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class BatchRenorm1d(nn.Module):
+    """BatchRenorm Module (https://arxiv.org/abs/1702.03275).
+
+    The code is adapted from https://github.com/google-research/corenet
+
+    BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
+    it utilizes running statistics to normalize batches after an initial warmup phase.
+    This approach reduces the impact of "outlier" batches that may occur during
+    extended training periods, making BatchRenorm more robust for long training runs.
+
+    During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.
+
+    Args:
+        num_features (int): Number of features in the input tensor.
+
+    Keyword Args:
+        momentum (:obj:`float`, optional): Momentum factor for computing the running mean and variance.
+            Defaults to ``0.01``.
+        eps (:obj:`float`, optional): Small value added to the variance to avoid division by zero.
+            Defaults to ``1e-5``.
+        max_r (:obj:`float`, optional): Maximum value for the scaling factor r.
+            Defaults to ``3.0``.
+        max_d (:obj:`float`, optional): Maximum value for the bias factor d.
+            Defaults to ``5.0``.
+        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
+            Defaults to ``10000``.
+        smooth (bool, optional): if ``True``, the behavior smoothly transitions from regular
+            batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``).
+            Otherwise, the behavior will transition from batch-norm to batch-renorm when
+            ``iter=warmup_steps``. Defaults to ``False``.
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        *,
+        momentum: float = 0.01,
+        eps: float = 1e-5,
+        max_r: float = 3.0,
+        max_d: float = 5.0,
+        warmup_steps: int = 10000,
+        smooth: bool = False,
+    ):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.max_r = max_r
+        self.max_d = max_d
+        self.warmup_steps = warmup_steps
+        self.smooth = smooth
+
+        self.register_buffer(
+            "running_mean", torch.zeros(num_features, dtype=torch.float32)
+        )
+        self.register_buffer(
+            "running_var", torch.ones(num_features, dtype=torch.float32)
+        )
+        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64))
+        self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32))
+        self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not x.dim() >= 2:
+            raise ValueError(
+                f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}."
+            )
+
+        view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)
+
+        def _v(v):
+            return v.view(view_dims)
+
+        running_std = (self.running_var + self.eps).sqrt_()
+
+        if self.training:
+            reduce_dims = [i for i in range(x.dim()) if i != 1]
+            b_mean = x.mean(reduce_dims)
+            b_var = x.var(reduce_dims, unbiased=False)
+            b_std = (b_var + self.eps).sqrt_()
+
+            r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r)
+            d = torch.clamp(
+                (b_mean.detach() - self.running_mean) / running_std,
+                -self.max_d,
+                self.max_d,
+            )
+
+            # Compute warmup factor (0 during warmup, 1 after warmup)
+            if self.warmup_steps > 0:
+                if self.smooth:
+                    warmup_factor = self.num_batches_tracked / self.warmup_steps
+                else:
+                    warmup_factor = self.num_batches_tracked // self.warmup_steps
+                r = 1.0 + (r - 1.0) * warmup_factor
+                d = d * warmup_factor
+
+            x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)
+
+            unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1)
+            self.running_var += self.momentum * (unbiased_var - self.running_var)
+            self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
+            self.num_batches_tracked += 1
+            self.num_batches_tracked.clamp_max_(self.warmup_steps)
+        else:
+            x = (x - _v(self.running_mean)) / _v(running_std)
+
+        x = _v(self.weight) * x + _v(self.bias)
+        return x
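
In short, during training the layer normalizes with the batch statistics and then re-maps them toward the running statistics via r = clamp(sigma_batch / sigma_run, 1/max_r, max_r) and d = clamp((mu_batch - mu_run) / sigma_run, -max_d, max_d), with r and d annealed in over warmup_steps; in eval mode only the running statistics are used. A minimal usage sketch (not part of the diff; the import path follows the file list above):

    import torch
    from torchrl.modules.models.batchrenorm import BatchRenorm1d

    brn = BatchRenorm1d(8, momentum=0.01, warmup_steps=100, smooth=True)

    brn.train()
    for _ in range(10):                 # warmup: behaves like plain BatchNorm
        y = brn(torch.randn(32, 8))     # while updating running mean/var

    brn.eval()
    y_eval = brn(torch.randn(32, 8))    # eval: normalize with running stats only
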
torchrl/modules/models/decision_transformer.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import dataclasses
+import importlib.util
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+_has_transformers = importlib.util.find_spec("transformers") is not None
+
+
+class DecisionTransformer(nn.Module):
+    """Online Decision Transformer.
+
+    Described in https://arxiv.org/abs/2202.05607 .
+
+    The transformer utilizes the following default config to create the GPT2 model
+    if the user does not provide a specific config::
+
+        default_config = {
+            "n_embd": 256,
+            "n_layer": 4,
+            "n_head": 4,
+            "n_inner": 1024,
+            "activation": "relu",
+            "n_positions": 1024,
+            "resid_pdrop": 0.1,
+            "attn_pdrop": 0.1,
+        }
+
+    Args:
+        state_dim (int): dimension of the state space
+        action_dim (int): dimension of the action space
+        config (:obj:`~.DTConfig` or dict, optional): transformer architecture configuration,
+            used to create the GPT2Config from transformers.
+            Defaults to ``default_config``.
+
+
+    Example:
+        >>> config = DecisionTransformer.default_config()
+        >>> config.n_embd = 128
+        >>> print(config)
+        DTConfig(n_embd: 128, n_layer: 4, n_head: 4, n_inner: 1024, activation: relu, n_positions: 1024, resid_pdrop: 0.1, attn_pdrop: 0.1)
+        >>> # alternatively
+        >>> config = DecisionTransformer.DTConfig(n_embd=128)
+        >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config)
+        >>> batch_size = [3, 32]
+        >>> length = 10
+        >>> observation = torch.randn(*batch_size, length, 4)
+        >>> action = torch.randn(*batch_size, length, 2)
+        >>> return_to_go = torch.randn(*batch_size, length, 1)
+        >>> output = model(observation, action, return_to_go)
+        >>> output.shape
+        torch.Size([3, 32, 10, 128])
+
+    """
+
+    @dataclass
+    class DTConfig:
+        """Default configuration for DecisionTransformer."""
+
+        n_embd: Any = 256
+        n_layer: Any = 4
+        n_head: Any = 4
+        n_inner: Any = 1024
+        activation: Any = "relu"
+        n_positions: Any = 1024
+        resid_pdrop: Any = 0.1
+        attn_pdrop: Any = 0.1
+
+        def __repr__(self):
+            fields = []
+            for f in dataclasses.fields(self):
+                value = getattr(self, f.name)
+                fields.append(f"{f.name}: {value}")
+            fields = ", ".join(fields)
+            return f"{self.__class__.__name__}({fields})"
+
+    @classmethod
+    def default_config(cls):
+        return cls.DTConfig()
+
+    def __init__(
+        self,
+        state_dim,
+        action_dim,
+        config: dict | DTConfig | None = None,
+        device: torch.device | None = None,
+    ):
+
+        if not _has_transformers:
+            raise ImportError(
+                "transformers is not installed. Please install it with `pip install transformers`."
+            )
+        import transformers
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+        if config is None:
+            config = self.default_config()
+        if isinstance(config, self.DTConfig):
+            config = dataclasses.asdict(config)
+        if not isinstance(config, dict):
+            try:
+                config = dict(config)
+            except Exception as err:
+                raise TypeError(
+                    f"Config of type {type(config)} is not supported."
+                ) from err
+
+        super().__init__()
+
+        with torch.device(device) if device is not None else nullcontext():
+            gpt_config = transformers.GPT2Config(
+                n_embd=config["n_embd"],
+                n_layer=config["n_layer"],
+                n_head=config["n_head"],
+                n_inner=config["n_inner"],
+                activation_function=config["activation"],
+                n_positions=config["n_positions"],
+                resid_pdrop=config["resid_pdrop"],
+                attn_pdrop=config["attn_pdrop"],
+                vocab_size=1,
+            )
+            self.state_dim = state_dim
+            self.action_dim = action_dim
+            self.hidden_size = config["n_embd"]
+
+            self.transformer = GPT2Model(config=gpt_config)
+
+            self.embed_return = torch.nn.Linear(1, self.hidden_size)
+            self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
+            self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)
+
+            self.embed_ln = nn.LayerNorm(self.hidden_size)
+
+    def forward(
+        self,
+        observation: torch.Tensor,
+        action: torch.Tensor,
+        return_to_go: torch.Tensor,
+    ):
+        batch_size, seq_length = observation.shape[:-2], observation.shape[-2]
+        batch_size_orig = batch_size
+        if len(batch_size) != 1:
+            # TODO: vmap over transformer once this is possible
+            observation = observation.view(-1, *observation.shape[-2:])
+            action = action.view(-1, *action.shape[-2:])
+            return_to_go = return_to_go.view(-1, *return_to_go.shape[-2:])
+            batch_size = torch.Size([batch_size.numel()])
+
+        # embed each modality with a different head
+        state_embeddings = self.embed_state(observation)
+        action_embeddings = self.embed_action(action)
+        returns_embeddings = self.embed_return(return_to_go)
+
+        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
+        # which works nicely in an autoregressive sense since states predict actions
+        stacked_inputs = torch.stack(
+            (returns_embeddings, state_embeddings, action_embeddings), dim=-2
+        ).reshape(*batch_size, 3 * seq_length, self.hidden_size)
+        stacked_inputs = self.embed_ln(stacked_inputs)
+
+        # we feed in the input embeddings (not word indices as in NLP) to the model
+        transformer_outputs = self.transformer(
+            inputs_embeds=stacked_inputs,
+        )
+        x = transformer_outputs["last_hidden_state"]
+
+        # reshape x so that the second dimension corresponds to the original
+        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
+        x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).transpose(-3, -2)
+        if batch_size_orig is batch_size:
+            return x[..., 1, :, :]  # only state tokens
+        return x[..., 1, :, :].reshape(*batch_size_orig, *x.shape[-2:])
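
For context (not part of the diff, and assuming transformers is installed as the code requires), the model accepts either the DTConfig dataclass or a plain dict with the same keys; leading batch dimensions are flattened internally and restored on the output, which contains only the state-token embeddings:

    import torch
    from torchrl.modules.models.decision_transformer import DecisionTransformer

    # Small config to keep the sketch light; any of the DTConfig keys can be overridden.
    config = DecisionTransformer.DTConfig(n_embd=64, n_layer=2, n_head=2, n_inner=256)
    model = DecisionTransformer(state_dim=4, action_dim=2, config=config)

    obs = torch.randn(3, 32, 10, 4)    # two leading batch dims, sequence length 10
    act = torch.randn(3, 32, 10, 2)
    rtg = torch.randn(3, 32, 10, 1)
    out = model(obs, act, rtg)         # expected shape: torch.Size([3, 32, 10, 64])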