torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/gail/gail.py
@@ -0,0 +1,327 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """GAIL Example.
+
+ This is a self-contained example of an offline GAIL training script.
+
+ The helper functions for GAIL are coded in gail_utils.py and the helper functions for PPO in ppo_utils.py.
+
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import tqdm
+ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
+ from ppo_utils import eval_model, make_env, make_ppo_models
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import compile_with_warmup, get_available_device, timeit
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.envs import set_gym_backend
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
+ from torchrl.objectives.value.advantages import GAE
+ from torchrl.record import VideoRecorder
+ from torchrl.record.loggers import generate_exp_name, get_logger
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(config_path="", config_name="config")
+ def main(cfg: DictConfig):  # noqa: F821
+     set_gym_backend(cfg.env.backend).set()
+
+     device = (
+         torch.device(cfg.gail.device) if cfg.gail.device else get_available_device()
+     )
+     num_mini_batches = (
+         cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
+     )
+     total_network_updates = (
+         (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
+         * cfg.ppo.loss.ppo_epochs
+         * num_mini_batches
+     )
+
+     # Create logger
+     exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="gail_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+
+     # Create models (check utils_mujoco.py)
+     actor, critic = make_ppo_models(
+         cfg.env.env_name, compile=cfg.compile.compile, device=device
+     )
+
+     # Create data buffer
+     data_buffer = TensorDictReplayBuffer(
+         storage=LazyTensorStorage(
+             cfg.ppo.collector.frames_per_batch,
+             device=device,
+             compilable=cfg.compile.compile,
+         ),
+         sampler=SamplerWithoutReplacement(),
+         batch_size=cfg.ppo.loss.mini_batch_size,
+         compilable=cfg.compile.compile,
+     )
+
+     # Create loss and adv modules
+     adv_module = GAE(
+         gamma=cfg.ppo.loss.gamma,
+         lmbda=cfg.ppo.loss.gae_lambda,
+         value_network=critic,
+         average_gae=False,
+         device=device,
+     )
+
+     loss_module = ClipPPOLoss(
+         actor_network=actor,
+         critic_network=critic,
+         clip_epsilon=cfg.ppo.loss.clip_epsilon,
+         loss_critic_type=cfg.ppo.loss.loss_critic_type,
+         entropy_coeff=cfg.ppo.loss.entropy_coeff,
+         critic_coeff=cfg.ppo.loss.critic_coeff,
+         normalize_advantage=True,
+     )
+
+     # Create optimizers
+     actor_optim = torch.optim.Adam(
+         actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+     )
+     critic_optim = torch.optim.Adam(
+         critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+     )
+     optim = group_optimizers(actor_optim, critic_optim)
+     del actor_optim, critic_optim
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
+     # Create collector
+     collector = SyncDataCollector(
+         create_env_fn=make_env(cfg.env.env_name, device),
+         policy=actor,
+         frames_per_batch=cfg.ppo.collector.frames_per_batch,
+         total_frames=cfg.ppo.collector.total_frames,
+         device=device,
+         max_frames_per_traj=-1,
+         compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+
+     # Create replay buffer
+     replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
+     # Create Discriminator
+     discriminator = make_gail_discriminator(cfg, collector.env, device)
+
+     # Create loss
+     discriminator_loss = GAILLoss(
+         discriminator,
+         use_grad_penalty=cfg.gail.use_grad_penalty,
+         gp_lambda=cfg.gail.gp_lambda,
+     )
+
+     # Create optimizer
+     discriminator_optim = torch.optim.Adam(
+         params=discriminator.parameters(), lr=cfg.gail.lr
+     )
+
+     # Create test environment
+     logger_video = cfg.logger.video
+     test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
+     if logger_video:
+         test_env = test_env.append_transform(
+             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+         )
+     test_env.eval()
+     num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
+
+     def update(data, expert_data, num_network_updates=num_network_updates):
+         # Add collector data to expert data
+         expert_data.set(
+             discriminator_loss.tensor_keys.collector_action,
+             data["action"][: expert_data.batch_size[0]],
+         )
+         expert_data.set(
+             discriminator_loss.tensor_keys.collector_observation,
+             data["observation"][: expert_data.batch_size[0]],
+         )
+         d_loss = discriminator_loss(expert_data)
+
+         # Backward pass
+         d_loss.get("loss").backward()
+         discriminator_optim.step()
+         discriminator_optim.zero_grad(set_to_none=True)
+
+         # Compute discriminator reward
+         with torch.no_grad():
+             data = discriminator(data)
+         d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)
+
+         # Set discriminator rewards to tensordict
+         data.set(("next", "reward"), d_rewards)
+
+         # Update PPO
+         for _ in range(cfg_loss_ppo_epochs):
+             # Compute GAE
+             with torch.no_grad():
+                 data = adv_module(data)
+             data_reshape = data.reshape(-1)
+
+             # Update the data buffer
+             data_buffer.empty()
+             data_buffer.extend(data_reshape)
+
+             for batch in data_buffer:
+                 optim.zero_grad(set_to_none=True)
+
+                 # Linearly decrease the learning rate and clip epsilon
+                 alpha = torch.ones((), device=device)
+                 if cfg_optim_anneal_lr:
+                     alpha = 1 - (num_network_updates / total_network_updates)
+                     for group in optim.param_groups:
+                         group["lr"] = cfg_optim_lr * alpha
+                 if cfg_loss_anneal_clip_eps:
+                     loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+                 num_network_updates += 1
+
+                 # Forward pass PPO loss
+                 loss = loss_module(batch)
+                 critic_loss = loss["loss_critic"]
+                 actor_loss = loss["loss_objective"] + loss["loss_entropy"]
+
+                 # Backward pass
+                 (actor_loss + critic_loss).backward()
+
+                 # Update the networks
+                 optim.step()
+         return {"dloss": d_loss, "alpha": alpha}
+
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, warmup=2, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Training loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
+
+     # extract cfg variables
+     cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
+     cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
+     cfg_optim_lr = cfg.ppo.optim.lr
+     cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
+     cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
+     cfg_logger_test_interval = cfg.logger.test_interval
+     cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+
+     total_iter = len(collector)
+     collector_iter = iter(collector)
+     for i in range(total_iter):
+
+         timeit.printevery(1000, total_iter, erase=True)
+
+         with timeit("collection"):
+             data = next(collector_iter)
+
+         metrics_to_log = {}
+         frames_in_batch = data.numel()
+         collected_frames += frames_in_batch
+         pbar.update(data.numel())
+
+         with timeit("rb - sample expert"):
+             # Get expert data
+             expert_data = replay_buffer.sample()
+             expert_data = expert_data.to(device)
+
+         with timeit("update"):
+             torch.compiler.cudagraph_mark_step_begin()
+             metadata = update(data, expert_data)
+         d_loss = metadata["dloss"]
+         alpha = metadata["alpha"]
+
+         # Get training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+         if len(episode_rewards) > 0:
+             episode_length = data["next", "step_count"][data["next", "done"]]
+
+             metrics_to_log.update(
+                 {
+                     "train/reward": episode_rewards.mean().item(),
+                     "train/episode_length": episode_length.sum().item()
+                     / len(episode_length),
+                 }
+             )
+
+         metrics_to_log.update(
+             {
+                 "train/discriminator_loss": d_loss["loss"],
+                 "train/lr": alpha * cfg_optim_lr,
+                 "train/clip_epsilon": (
+                     alpha * cfg_loss_clip_epsilon
+                     if cfg_loss_anneal_clip_eps
+                     else cfg_loss_clip_epsilon
+                 ),
+             }
+         )
+
+         # evaluation
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
+             if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
+                 i * frames_in_batch
+             ) // cfg_logger_test_interval:
+                 actor.eval()
+                 test_rewards = eval_model(
+                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
+                 )
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_rewards.mean(),
+                     }
+                 )
+                 actor.train()
+         if logger is not None:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, i)
+
+     pbar.close()
+
+
+ if __name__ == "__main__":
+     main()
sota-implementations/gail/gail_utils.py
@@ -0,0 +1,68 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch.nn as nn
+ import torch.optim
+ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+ from torchrl.data.replay_buffers import SamplerWithoutReplacement
+ from torchrl.envs import DoubleToFloat
+ from torchrl.modules import SafeModule
+
+
+ # ====================================================================
+ # Offline Replay buffer
+ # ---------------------------
+
+
+ def make_offline_replay_buffer(rb_cfg):
+     data = D4RLExperienceReplay(
+         dataset_id=rb_cfg.dataset,
+         split_trajs=False,
+         batch_size=rb_cfg.batch_size,
+         sampler=SamplerWithoutReplacement(drop_last=False),
+         prefetch=4,
+         direct_download=True,
+     )
+
+     data.append_transform(DoubleToFloat())
+
+     return data
+
+
+ def make_gail_discriminator(cfg, train_env, device="cpu"):
+     """Make GAIL discriminator."""
+
+     state_dim = train_env.observation_spec["observation"].shape[0]
+     action_dim = train_env.action_spec.shape[0]
+
+     hidden_dim = cfg.gail.hidden_dim
+
+     # Define Discriminator Network
+     class Discriminator(nn.Module):
+         def __init__(self, state_dim, action_dim):
+             super().__init__()
+             self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
+             self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+             self.fc3 = nn.Linear(hidden_dim, 1)
+
+         def forward(self, state, action):
+             x = torch.cat([state, action], dim=1)
+             x = torch.relu(self.fc1(x))
+             x = torch.relu(self.fc2(x))
+             return torch.sigmoid(self.fc3(x))
+
+     d_module = SafeModule(
+         module=Discriminator(state_dim, action_dim),
+         in_keys=["observation", "action"],
+         out_keys=["d_logits"],
+     )
+     return d_module.to(device)
+
+
+ def log_metrics(logger, metrics, step):
+     if logger is not None:
+         for metric_name, metric_value in metrics.items():
+             logger.log_scalar(metric_name, metric_value, step)
sota-implementations/gail/ppo_utils.py
@@ -0,0 +1,157 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch.nn
+ import torch.optim
+
+ from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
+ from torchrl.envs import (
+     ClipTransform,
+     DoubleToFloat,
+     ExplorationType,
+     RewardSum,
+     StepCounter,
+     TransformedEnv,
+     VecNorm,
+ )
+ from torchrl.envs.libs.gym import GymEnv
+ from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
+ from torchrl.record import VideoRecorder
+
+
+ # ====================================================================
+ # Environment utils
+ # --------------------------------------------------------------------
+
+
+ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
+     env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
+     env = TransformedEnv(env)
+     env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+     env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
+     env.append_transform(RewardSum())
+     env.append_transform(StepCounter())
+     env.append_transform(DoubleToFloat(in_keys=["observation"]))
+     return env
+
+
+ # ====================================================================
+ # Model utils
+ # --------------------------------------------------------------------
+
+
+ def make_ppo_models_state(proof_environment, compile, device):
+
+     # Define input shape
+     input_shape = proof_environment.observation_spec["observation"].shape
+
+     # Define policy output distribution class
+     num_outputs = proof_environment.action_spec_unbatched.shape[-1]
+     distribution_class = TanhNormal
+     distribution_kwargs = {
+         "low": proof_environment.action_spec_unbatched.space.low.to(device),
+         "high": proof_environment.action_spec_unbatched.space.high.to(device),
+         "tanh_loc": False,
+         # "safe_tanh": not compile,
+     }
+
+     # Define policy architecture
+     policy_mlp = MLP(
+         in_features=input_shape[-1],
+         activation_class=torch.nn.Tanh,
+         out_features=num_outputs,  # predict only loc
+         num_cells=[64, 64],
+         device=device,
+     )
+
+     # Initialize policy weights
+     for layer in policy_mlp.modules():
+         if isinstance(layer, torch.nn.Linear):
+             torch.nn.init.orthogonal_(layer.weight, 1.0)
+             layer.bias.data.zero_()
+
+     # Add state-independent normal scale
+     policy_mlp = torch.nn.Sequential(
+         policy_mlp,
+         AddStateIndependentNormalScale(
+             proof_environment.action_spec_unbatched.shape[-1],
+             scale_lb=1e-8,
+             device=device,
+         ),
+     )
+
+     # Add probabilistic sampling of the actions
+     policy_module = ProbabilisticActor(
+         TensorDictModule(
+             module=policy_mlp,
+             in_keys=["observation"],
+             out_keys=["loc", "scale"],
+         ),
+         in_keys=["loc", "scale"],
+         spec=proof_environment.full_action_spec_unbatched.to(device),
+         distribution_class=distribution_class,
+         distribution_kwargs=distribution_kwargs,
+         return_log_prob=True,
+         default_interaction_type=ExplorationType.RANDOM,
+     )
+
+     # Define value architecture
+     value_mlp = MLP(
+         in_features=input_shape[-1],
+         activation_class=torch.nn.Tanh,
+         out_features=1,
+         num_cells=[64, 64],
+         device=device,
+     )
+
+     # Initialize value weights
+     for layer in value_mlp.modules():
+         if isinstance(layer, torch.nn.Linear):
+             torch.nn.init.orthogonal_(layer.weight, 0.01)
+             layer.bias.data.zero_()
+
+     # Define value module
+     value_module = ValueOperator(
+         value_mlp,
+         in_keys=["observation"],
+     )
+
+     return policy_module, value_module
+
+
+ def make_ppo_models(env_name, compile, device):
+     proof_environment = make_env(env_name, device=device)
+     actor, critic = make_ppo_models_state(
+         proof_environment, compile=compile, device=device
+     )
+     return actor, critic
+
+
+ # ====================================================================
+ # Evaluation utils
+ # --------------------------------------------------------------------
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
+
+
+ def eval_model(actor, test_env, num_episodes=3):
+     test_rewards = []
+     for _ in range(num_episodes):
+         td_test = test_env.rollout(
+             policy=actor,
+             auto_reset=True,
+             auto_cast_to_device=True,
+             break_when_any_done=True,
+             max_steps=10_000_000,
+         )
+         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+         test_rewards.append(reward.cpu())
+         test_env.apply(dump_video)
+     del td_test
+     return torch.cat(test_rewards, 0).mean()