torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
benchmarks/test_objectives_benchmarks.py
@@ -0,0 +1,1199 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ import argparse
+
+ import pytest
+ import torch
+ from packaging import version
+
+ from tensordict import TensorDict
+ from tensordict.nn import (
+     composite_lp_aggregate,
+     InteractionType,
+     NormalParamExtractor,
+     ProbabilisticTensorDictModule as ProbMod,
+     ProbabilisticTensorDictSequential as ProbSeq,
+     TensorDictModule as Mod,
+     TensorDictSequential as Seq,
+ )
+ from torch.nn import functional as F
+ from torchrl.data.tensor_specs import Bounded, Unbounded
+ from torchrl.modules import MLP, QValueActor, TanhNormal
+ from torchrl.objectives import (
+     A2CLoss,
+     ClipPPOLoss,
+     CQLLoss,
+     DDPGLoss,
+     DQNLoss,
+     IQLLoss,
+     REDQLoss,
+     ReinforceLoss,
+     SACLoss,
+     TD3Loss,
+ )
+ from torchrl.objectives.deprecated import REDQLoss_deprecated
+ from torchrl.objectives.value import GAE
+ from torchrl.objectives.value.functional import (
+     generalized_advantage_estimate,
+     td0_return_estimate,
+     td1_return_estimate,
+     td_lambda_return_estimate,
+     vec_generalized_advantage_estimate,
+     vec_td1_return_estimate,
+     vec_td_lambda_return_estimate,
+ )
+
+ TORCH_VERSION = torch.__version__
+ FULLGRAPH = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse(
+     "2.5.0"
+ )  # Anything from 2.5, incl. nightlies, allows for fullgraph
+
+
+ # @pytest.fixture(scope="module", autouse=True)
+ # def set_default_device():
+ #     cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
+ #     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ #     torch.set_default_device(device)
+ #     yield
+ #     torch.set_default_device(cur_device)
+
+
+ class setup_value_fn:
+     def __init__(self, has_lmbda, has_state_value):
+         self.has_lmbda = has_lmbda
+         self.has_state_value = has_state_value
+
+     def __call__(
+         self,
+         b=300,
+         t=500,
+         d=1,
+         gamma=0.95,
+         lmbda=0.95,
+     ):
+         torch.manual_seed(0)
+         device = "cuda:0" if torch.cuda.device_count() else "cpu"
+         values = torch.randn(b, t, d, device=device)
+         next_values = torch.randn(b, t, d, device=device)
+         reward = torch.randn(b, t, d, device=device)
+         done = torch.zeros(b, t, d, dtype=torch.bool, device=device).bernoulli_(0.1)
+         kwargs = {
+             "gamma": gamma,
+             "next_state_value": next_values,
+             "reward": reward,
+             "done": done,
+         }
+         if self.has_lmbda:
+             kwargs["lmbda"] = lmbda
+
+         if self.has_state_value:
+             kwargs["state_value"] = values
+
+         return ((), kwargs)
+
+
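setup_value_fn above is consumed by pytest-benchmark's pedantic mode: the setup callable is invoked before every round, and the returned (args, kwargs) pair is unpacked into the benchmarked function, so tensor construction stays outside the timed region (which is also why the tests here pass iterations=1, as pedantic expects when a setup function is used). A minimal sketch of the same pattern with a hypothetical my_fn, not part of this file:

    def my_fn(x, scale=1.0):
        return x * scale

    def test_my_fn(benchmark):
        def setup():
            # rebuilt before each round; time spent here is not measured
            return (2.0,), {"scale": 3.0}

        benchmark.pedantic(my_fn, setup=setup, iterations=1, rounds=10)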
+ @pytest.mark.parametrize(
+     "val_fn,has_lmbda,has_state_value",
+     [
+         [generalized_advantage_estimate, True, True],
+         [vec_generalized_advantage_estimate, True, True],
+         [td0_return_estimate, False, False],
+         [td1_return_estimate, False, False],
+         [vec_td1_return_estimate, False, False],
+         [td_lambda_return_estimate, True, False],
+         [vec_td_lambda_return_estimate, True, False],
+     ],
+ )
+ def test_values(benchmark, val_fn, has_lmbda, has_state_value):
+     benchmark.pedantic(
+         val_fn,
+         setup=setup_value_fn(
+             has_lmbda=has_lmbda,
+             has_state_value=has_state_value,
+         ),
+         iterations=1,
+         rounds=50,
+     )
+
+
+ @pytest.mark.parametrize(
+     "gae_fn,gamma_tensor,batches,timesteps",
+     [
+         [generalized_advantage_estimate, False, 1, 512],
+         [vec_generalized_advantage_estimate, True, 1, 512],
+         [vec_generalized_advantage_estimate, False, 1, 512],
+         [vec_generalized_advantage_estimate, True, 32, 512],
+         [vec_generalized_advantage_estimate, False, 32, 512],
+     ],
+ )
+ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
+     size = (batches, timesteps, 1)
+
+     torch.manual_seed(0)
+     device = "cuda:0" if torch.cuda.device_count() else "cpu"
+     values = torch.randn(*size, device=device)
+     next_values = torch.randn(*size, device=device)
+     reward = torch.randn(*size, device=device)
+     done = torch.zeros(*size, dtype=torch.bool, device=device).bernoulli_(0.1)
+
+     gamma = 0.99
+     if gamma_tensor:
+         gamma = torch.full(size, gamma, device=device)
+     lmbda = 0.95
+
+     benchmark(
+         gae_fn,
+         gamma=gamma,
+         lmbda=lmbda,
+         state_value=values,
+         next_state_value=next_values,
+         reward=reward,
+         done=done,
+     )
+
+
+ def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3):
+     if compile:
+         if isinstance(compile, str):
+             fn = torch.compile(fn, mode=compile, fullgraph=fullgraph)
+         else:
+             fn = torch.compile(fn, fullgraph=fullgraph)
+
+         for _ in range(warmup):
+             fn(td)
+
+     return fn
+
+
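_maybe_compile maps the `compile` parametrization onto torch.compile: False leaves the loss in eager mode, True compiles with default options, and a string such as "reduce-overhead" is forwarded as the compile mode; the warm-up calls trigger compilation before measurement starts. A standalone sketch of the same warm-up-before-timing idea, with a hypothetical function f:

    import torch

    def f(x):
        return (x * 2).sum()

    g = torch.compile(f, mode="reduce-overhead", fullgraph=True)
    x = torch.randn(128)
    for _ in range(3):
        g(x)  # warm-up: compilation happens here, outside any timed region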
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_dqn_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     net = MLP(
+         in_features=n_obs,
+         out_features=n_act,
+         depth=depth,
+         num_cells=ncells,
+         device=device,
+     )
+     action_space = "one-hot"
+     mod = QValueActor(net, in_keys=["obs"], action_space=action_space)
+     loss = DQNLoss(value_network=mod, action_space=action_space)
+     td = TensorDict(
+         {
+             "obs": torch.randn(batch, n_obs),
+             "action": F.one_hot(torch.randint(n_act, (batch,))),
+             "next": {
+                 "obs": torch.randn(batch, n_obs),
+                 "done": torch.zeros(batch, 1, dtype=torch.bool),
+                 "reward": torch.randn(batch, 1),
+             },
+         },
+         [batch],
+         device=device,
+     )
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
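The backward variants here and in the tests below all rely on the fact that TorchRL loss modules return a TensorDict whose loss terms live under keys starting with "loss"; summing those entries gives the scalar to backpropagate, the same pattern a training loop would use. A sketch, assuming a loss module `loss` and a batch `td` shaped like the ones in this file:

    losses = loss(td)  # TensorDict with entries such as "loss_qvalue"
    total = sum(val for key, val in losses.items() if key.startswith("loss"))
    total.backward()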
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_ddpg_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor_head = Mod(actor, in_keys=["hidden"], out_keys=["action"])
+     actor = Seq(common, actor_head)
+     value = Mod(value, in_keys=["hidden", "action"], out_keys=["state_action_value"])
+     value(actor(td))
+
+     loss = DDPGLoss(actor, value)
+
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_sac_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     value_head = Mod(
+         value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+     )
+     value = Seq(common, value_head)
+     value(actor(td.clone()))
+
+     loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
+
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ # FIXME: fix this
+ @pytest.mark.skipif(torch.cuda.is_available(), reason="Currently fails on GPU")
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_redq_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             return_log_prob=True,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     value_head = Mod(
+         value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+     )
+     value = Seq(common, value_head)
+     value(actor(td.copy()))
+
+     loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
+
+     loss(td)
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             totalloss = sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             )
+             totalloss.backward()
+
+         loss_and_bw(td)
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_redq_deprec_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             return_log_prob=True,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     value_head = Mod(
+         value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+     )
+     value = Seq(common, value_head)
+     value(actor(td.copy()))
+
+     loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))
+
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_td3_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+             return_log_prob=True,
+             default_interaction_type=InteractionType.DETERMINISTIC,
+         ),
+     )
+     value_head = Mod(
+         value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+     )
+     value = Seq(common, value_head)
+     value(actor(td.clone()))
+
+     loss = TD3Loss(
+         actor,
+         value,
+         action_spec=Bounded(shape=(n_act,), low=-1, high=1),
+     )
+
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_cql_speed(
+     benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch]
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         device=device,
+     )
+     common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     value_head = Mod(
+         value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
+     )
+     value = Seq(common, value_head)
+     value(actor(td.copy()))
+
+     loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
+
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_a2c_speed(
+     benchmark,
+     backward,
+     compile,
+     n_obs=8,
+     n_act=4,
+     n_hidden=64,
+     ncells=128,
+     batch=128,
+     T=10,
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common_net = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value_net = MLP(
+         in_features=n_hidden,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch, T]
+     if composite_lp_aggregate():
+         raise RuntimeError(
+             "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+         )
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "action_log_prob": torch.randn(*batch),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         names=[None, "time"],
+         device=device,
+     )
+     common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
+     actor(td.clone())
+     critic(td.clone())
+
+     loss = A2CLoss(actor_network=actor, critic_network=critic)
+     advantage = GAE(
+         value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
+     )
+     advantage(td)
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_ppo_speed(
+     benchmark,
+     backward,
+     compile,
+     n_obs=8,
+     n_act=4,
+     n_hidden=64,
+     ncells=128,
+     batch=128,
+     T=10,
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common_net = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value_net = MLP(
+         in_features=n_hidden,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch, T]
+     if composite_lp_aggregate():
+         raise RuntimeError(
+             "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+         )
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "action_log_prob": torch.randn(*batch),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         names=[None, "time"],
+         device=device,
+     )
+     common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
+     actor(td.clone())
+     critic(td.clone())
+
+     loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
+     advantage = GAE(
+         value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
+     )
+     advantage(td)
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_reinforce_speed(
+     benchmark,
+     backward,
+     compile,
+     n_obs=8,
+     n_act=4,
+     n_hidden=64,
+     ncells=128,
+     batch=128,
+     T=10,
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common_net = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value_net = MLP(
+         in_features=n_hidden,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch, T]
+     if composite_lp_aggregate():
+         raise RuntimeError(
+             "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+         )
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "action_log_prob": torch.randn(*batch),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         names=[None, "time"],
+         device=device,
+     )
+     common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
+     actor(td.clone())
+     critic(td.clone())
+
+     loss = ReinforceLoss(actor_network=actor, critic_network=critic)
+     advantage = GAE(
+         value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device
+     )
+     advantage(td)
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ @pytest.mark.parametrize("backward", [None, "backward"])
+ @pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+ def test_iql_speed(
+     benchmark,
+     backward,
+     compile,
+     n_obs=8,
+     n_act=4,
+     n_hidden=64,
+     ncells=128,
+     batch=128,
+     T=10,
+ ):
+     if compile == "reduce-overhead" and backward is not None:
+         pytest.skip("reduce-overhead with backward causes segfaults in CI")
+     if compile:
+         torch._dynamo.reset_code_caches()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     common_net = MLP(
+         num_cells=ncells,
+         in_features=n_obs,
+         depth=3,
+         out_features=n_hidden,
+         device=device,
+     )
+     actor_net = MLP(
+         num_cells=ncells,
+         in_features=n_hidden,
+         depth=2,
+         out_features=2 * n_act,
+         device=device,
+     )
+     value_net = MLP(
+         in_features=n_hidden,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     qvalue_net = MLP(
+         in_features=n_hidden + n_act,
+         num_cells=ncells,
+         depth=2,
+         out_features=1,
+         device=device,
+     )
+     batch = [batch, T]
+     if composite_lp_aggregate():
+         raise RuntimeError(
+             "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+         )
+     td = TensorDict(
+         {
+             "obs": torch.randn(*batch, n_obs),
+             "action": torch.randn(*batch, n_act),
+             "action_log_prob": torch.randn(*batch),
+             "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             "next": {
+                 "obs": torch.randn(*batch, n_obs),
+                 "reward": torch.randn(*batch, 1),
+                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
+             },
+         },
+         batch,
+         names=[None, "time"],
+         device=device,
+     )
+     common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
+     actor = ProbSeq(
+         common,
+         Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
+         Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+         ProbMod(
+             in_keys=["loc", "scale"],
+             out_keys=["action"],
+             distribution_class=TanhNormal,
+             distribution_kwargs={"safe_tanh": False},
+         ),
+     )
+     value = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
+     qvalue = Seq(
+         common,
+         Mod(qvalue_net, in_keys=["hidden", "action"], out_keys=["state_action_value"]),
+     )
+     qvalue(actor(td.clone()))
+     value(td.clone())
+
+     loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue)
+     loss(td)
+
+     loss = _maybe_compile(loss, compile, td)
+
+     if backward:
+
+         def loss_and_bw(td):
+             losses = loss(td)
+             sum(
+                 [val for key, val in losses.items() if key.startswith("loss")]
+             ).backward()
+
+         benchmark.pedantic(
+             loss_and_bw,
+             args=(td,),
+             setup=loss.zero_grad,
+             iterations=1,
+             warmup_rounds=5,
+             rounds=50,
+         )
+     else:
+         benchmark(loss, td)
+
+
+ if __name__ == "__main__":
+     args, unknown = argparse.ArgumentParser().parse_known_args()
+     pytest.main(
+         [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+         + unknown
+     )
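The __main__ block above mirrors a plain pytest invocation. Assuming pytest and pytest-benchmark are installed, an equivalent programmatic run restricted to one group of benchmarks might look like this sketch (the "-k gae" filter is illustrative, not part of the file):

    import pytest

    pytest.main(
        [
            "benchmarks/test_objectives_benchmarks.py",
            "-k", "gae",  # select only the test_gae_speed variants
            "--benchmark-group-by", "func",  # group results per test function
        ]
    )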