torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,996 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
12
+ from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
13
+ from tensordict.utils import NestedKey
14
+ from torch import Tensor
15
+
16
+ from torchrl.data.tensor_specs import TensorSpec
17
+ from torchrl.data.utils import _find_action_space
18
+ from torchrl.objectives.common import LossModule
19
+ from torchrl.objectives.utils import (
20
+ _GAMMA_LMBDA_DEPREC_ERROR,
21
+ _pseudo_vmap,
22
+ _reduce,
23
+ _vmap_func,
24
+ default_value_kwargs,
25
+ distance_loss,
26
+ ValueEstimators,
27
+ )
28
+ from torchrl.objectives.value import (
29
+ TD0Estimator,
30
+ TD1Estimator,
31
+ TDLambdaEstimator,
32
+ ValueEstimatorBase,
33
+ )
34
+
35
+
36
+ class IQLLoss(LossModule):
37
+ r"""TorchRL implementation of the IQL loss.
38
+
39
+ Presented in "Offline Reinforcement Learning with Implicit Q-Learning" https://arxiv.org/abs/2110.06169
40
+
41
+ Args:
42
+ actor_network (ProbabilisticTensorDictSequential): stochastic actor
43
+ qvalue_network (TensorDictModule): Q(s, a) parametric model
44
+ If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
45
+ times. If a list of modules is passed, their
46
+ parameters will be stacked unless they share the same identity (in which case
47
+ the original parameter will be expanded).
48
+
49
+ .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
50
+ and all the parameters will be considered as untied.
51
+
52
+ value_network (TensorDictModule, optional): V(s) parametric model.
53
+
54
+ Keyword Args:
55
+ num_qvalue_nets (integer, optional): number of Q-Value networks used.
56
+ Defaults to ``2``.
57
+ loss_function (str, optional): loss function to be used with
58
+ the value function loss. Default is `"smooth_l1"`.
59
+ temperature (:obj:`float`, optional): Inverse temperature (beta).
60
+ For smaller hyperparameter values, the objective behaves similarly to
61
+ behavioral cloning, while for larger values, it attempts to recover the
62
+ maximum of the Q-function. Defaults to ``1.0``.
63
+ expectile (:obj:`float`, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial
64
+ for antmaze tasks that require dynamic programming ("stitching"). Defaults to ``0.5``.
65
+ priority_key (str, optional): [Deprecated, use .set_keys(priority=priority_key) instead]
66
+ tensordict key where to write the priority (for prioritized replay
67
+ buffer usage). Default is `"td_error"`.
68
+ separate_losses (bool, optional): if ``True``, shared parameters between
69
+ policy and critic will only be trained on the policy loss.
70
+ Defaults to ``False``, i.e., gradients are propagated to shared
71
+ parameters for both policy and critic losses.
72
+ reduction (str, optional): Specifies the reduction to apply to the output:
73
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
74
+ ``"mean"``: the sum of the output will be divided by the number of
75
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
76
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
77
+ Defaults to ``False``.
78
+
79
+ Examples:
80
+ >>> import torch
81
+ >>> from torch import nn
82
+ >>> from torchrl.data import Bounded
83
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
84
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
85
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
86
+ >>> from torchrl.objectives.iql import IQLLoss
87
+ >>> from tensordict import TensorDict
88
+ >>> n_act, n_obs = 4, 3
89
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
90
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
91
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
92
+ >>> actor = ProbabilisticActor(
93
+ ... module=module,
94
+ ... in_keys=["loc", "scale"],
95
+ ... spec=spec,
96
+ ... distribution_class=TanhNormal)
97
+ >>> class QValueClass(nn.Module):
98
+ ... def __init__(self):
99
+ ... super().__init__()
100
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
101
+ ... def forward(self, obs, act):
102
+ ... return self.linear(torch.cat([obs, act], -1))
103
+ >>> qvalue = SafeModule(
104
+ ... QValueClass(),
105
+ ... in_keys=["observation", "action"],
106
+ ... out_keys=["state_action_value"],
107
+ ... )
108
+ >>> value = SafeModule(
109
+ ... nn.Linear(n_obs, 1),
110
+ ... in_keys=["observation"],
111
+ ... out_keys=["state_value"],
112
+ ... )
113
+ >>> loss = IQLLoss(actor, qvalue, value)
114
+ >>> batch = [2, ]
115
+ >>> action = spec.rand(batch)
116
+ >>> data = TensorDict({
117
+ ... "observation": torch.randn(*batch, n_obs),
118
+ ... "action": action,
119
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
120
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
121
+ ... ("next", "reward"): torch.randn(*batch, 1),
122
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
123
+ ... }, batch)
124
+ >>> loss(data)
125
+ TensorDict(
126
+ fields={
127
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
128
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
129
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
130
+ loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
131
+ batch_size=torch.Size([]),
132
+ device=None,
133
+ is_shared=False)
134
+
135
+ This class is compatible with non-tensordict based modules too and can be
136
+ used without resorting to any tensordict-related primitive. In this case,
137
+ the expected keyword arguments are:
138
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network
139
+ The return value is a tuple of tensors in the following order:
140
+ ``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``.
141
+
142
+ Examples:
143
+ >>> import torch
144
+ >>> from torch import nn
145
+ >>> from torchrl.data import Bounded
146
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
147
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
148
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
149
+ >>> from torchrl.objectives.iql import IQLLoss
150
+ >>> _ = torch.manual_seed(42)
151
+ >>> n_act, n_obs = 4, 3
152
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
153
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
154
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
155
+ >>> actor = ProbabilisticActor(
156
+ ... module=module,
157
+ ... in_keys=["loc", "scale"],
158
+ ... spec=spec,
159
+ ... distribution_class=TanhNormal)
160
+ >>> class QValueClass(nn.Module):
161
+ ... def __init__(self):
162
+ ... super().__init__()
163
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
164
+ ... def forward(self, obs, act):
165
+ ... return self.linear(torch.cat([obs, act], -1))
166
+ >>> qvalue = SafeModule(
167
+ ... QValueClass(),
168
+ ... in_keys=["observation", "action"],
169
+ ... out_keys=["state_action_value"],
170
+ ... )
171
+ >>> value = SafeModule(
172
+ ... nn.Linear(n_obs, 1),
173
+ ... in_keys=["observation"],
174
+ ... out_keys=["state_value"],
175
+ ... )
176
+ >>> loss = IQLLoss(actor, qvalue, value)
177
+ >>> batch = [2, ]
178
+ >>> action = spec.rand(batch)
179
+ >>> loss_actor, loss_qvalue, loss_value, entropy = loss(
180
+ ... observation=torch.randn(*batch, n_obs),
181
+ ... action=action,
182
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
183
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
184
+ ... next_observation=torch.zeros(*batch, n_obs),
185
+ ... next_reward=torch.randn(*batch, 1))
186
+ >>> loss_actor.backward()
187
+
188
+
189
+ The output keys can also be filtered using the :meth:`IQLLoss.select_out_keys`
190
+ method.
191
+
192
+ Examples:
193
+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
194
+ >>> loss_actor, loss_qvalue = loss(
195
+ ... observation=torch.randn(*batch, n_obs),
196
+ ... action=action,
197
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
198
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
199
+ ... next_observation=torch.zeros(*batch, n_obs),
200
+ ... next_reward=torch.randn(*batch, 1))
201
+ >>> loss_actor.backward()
202
+ """
203
+
204
+ @dataclass
205
+ class _AcceptedKeys:
206
+ """Maintains default values for all configurable tensordict keys.
207
+
208
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
209
+ default values
210
+
211
+ Attributes:
212
+ value (NestedKey): The input tensordict key where the state value is expected.
213
+ Will be used for the underlying value estimator. Defaults to ``"state_value"``.
214
+ action (NestedKey): The input tensordict key where the action is expected.
215
+ Defaults to ``"action"``.
216
+ log_prob (NestedKey): The input tensordict key where the log probability is expected.
217
+ Defaults to ``"_log_prob"``.
218
+ priority (NestedKey): The input tensordict key where the target priority is written to.
219
+ Defaults to ``"td_error"``.
220
+ state_action_value (NestedKey): The input tensordict key where the
221
+ state action value is expected. Will be used for the underlying
222
+ value estimator as value key. Defaults to ``"state_action_value"``.
223
+ reward (NestedKey): The input tensordict key where the reward is expected.
224
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
225
+ done (NestedKey): The key in the input TensorDict that indicates
226
+ whether a trajectory is done. Will be used for the underlying value estimator.
227
+ Defaults to ``"done"``.
228
+ terminated (NestedKey): The key in the input TensorDict that indicates
229
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
230
+ Defaults to ``"terminated"``.
231
+ """
232
+
233
+ value: NestedKey = "state_value"
234
+ action: NestedKey = "action"
235
+ log_prob: NestedKey = "_log_prob"
236
+ priority: NestedKey = "td_error"
237
+ state_action_value: NestedKey = "state_action_value"
238
+ reward: NestedKey = "reward"
239
+ done: NestedKey = "done"
240
+ terminated: NestedKey = "terminated"
241
+
242
+ tensor_keys: _AcceptedKeys
243
+ default_keys = _AcceptedKeys
244
+ default_value_estimator = ValueEstimators.TD0
245
+ out_keys = [
246
+ "loss_actor",
247
+ "loss_qvalue",
248
+ "loss_value",
249
+ "entropy",
250
+ ]
251
+
252
+ actor_network: TensorDictModule
253
+ actor_network_params: TensorDictParams
254
+ target_actor_network_params: TensorDictParams
255
+ qvalue_network: TensorDictModule
256
+ qvalue_network_params: TensorDictParams
257
+ target_qvalue_network_params: TensorDictParams
258
+ value_network: TensorDictModule | None
259
+ value_network_params: TensorDictParams | None
260
+ target_value_network_params: TensorDictParams | None
261
+
262
+ def __init__(
263
+ self,
264
+ actor_network: ProbabilisticTensorDictSequential,
265
+ qvalue_network: TensorDictModule | list[TensorDictModule],
266
+ value_network: TensorDictModule | None,
267
+ *,
268
+ num_qvalue_nets: int = 2,
269
+ loss_function: str = "smooth_l1",
270
+ temperature: float = 1.0,
271
+ expectile: float = 0.5,
272
+ gamma: float | None = None,
273
+ priority_key: str | None = None,
274
+ separate_losses: bool = False,
275
+ reduction: str | None = None,
276
+ deactivate_vmap: bool = False,
277
+ ) -> None:
278
+ self._in_keys = None
279
+ self._out_keys = None
280
+ if reduction is None:
281
+ reduction = "mean"
282
+ super().__init__()
283
+ self._set_deprecated_ctor_keys(priority=priority_key)
284
+
285
+ self.deactivate_vmap = deactivate_vmap
286
+
287
+ # IQL parameter
288
+ self.temperature = temperature
289
+ self.expectile = expectile
290
+
291
+ # Actor Network
292
+ self.convert_to_functional(
293
+ actor_network,
294
+ "actor_network",
295
+ create_target_params=False,
296
+ )
297
+ if separate_losses:
298
+ # we want to make sure there are no duplicates in the params: the
299
+ # params of critic must be refs to actor if they're shared
300
+ policy_params = list(actor_network.parameters())
301
+ else:
302
+ policy_params = None
303
+ # Value Function Network
304
+ self.convert_to_functional(
305
+ value_network,
306
+ "value_network",
307
+ create_target_params=False,
308
+ compare_against=policy_params,
309
+ )
310
+
311
+ # Q Function Network
312
+ self.delay_qvalue = True
313
+ self.num_qvalue_nets = num_qvalue_nets
314
+ if separate_losses and policy_params is not None:
315
+ qvalue_policy_params = list(actor_network.parameters()) + list(
316
+ value_network.parameters()
317
+ )
318
+ else:
319
+ qvalue_policy_params = None
320
+ self.convert_to_functional(
321
+ qvalue_network,
322
+ "qvalue_network",
323
+ num_qvalue_nets,
324
+ create_target_params=True,
325
+ compare_against=qvalue_policy_params,
326
+ )
327
+
328
+ self.loss_function = loss_function
329
+ if gamma is not None:
330
+ raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
331
+ self._make_vmap()
332
+ self.reduction = reduction
333
+
334
+ def _make_vmap(self):
335
+ self._vmap_qvalue_networkN0 = _vmap_func(
336
+ self.qvalue_network,
337
+ (None, 0),
338
+ randomness=self.vmap_randomness,
339
+ pseudo_vmap=self.deactivate_vmap,
340
+ )
341
+
342
+ @property
343
+ def device(self) -> torch.device:
344
+ raise RuntimeError(
345
+ "The device attributes of the losses is deprecated since v0.3.",
346
+ )
347
+
348
+ def _set_in_keys(self):
349
+ keys = [
350
+ self.tensor_keys.action,
351
+ ("next", self.tensor_keys.reward),
352
+ ("next", self.tensor_keys.done),
353
+ ("next", self.tensor_keys.terminated),
354
+ *self.actor_network.in_keys,
355
+ *[("next", key) for key in self.actor_network.in_keys],
356
+ *self.qvalue_network.in_keys,
357
+ *self.value_network.in_keys,
358
+ ]
359
+ self._in_keys = list(set(keys))
360
+
361
+ @property
362
+ def in_keys(self):
363
+ if self._in_keys is None:
364
+ self._set_in_keys()
365
+ return self._in_keys
366
+
367
+ @in_keys.setter
368
+ def in_keys(self, values):
369
+ self._in_keys = values
370
+
371
+ @staticmethod
372
+ def loss_value_diff(diff, expectile=0.8):
373
+ """Loss function for iql expectile value difference."""
374
+ weight = torch.where(diff > 0, expectile, (1 - expectile))
375
+ return weight * (diff**2)
376
+
377
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
378
+ if self._value_estimator is not None:
379
+ self._value_estimator.set_keys(
380
+ value=self._tensor_keys.value,
381
+ reward=self.tensor_keys.reward,
382
+ done=self.tensor_keys.done,
383
+ terminated=self.tensor_keys.terminated,
384
+ )
385
+ self._set_in_keys()
386
+
387
+ @dispatch
388
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
389
+ loss_actor, metadata = self.actor_loss(tensordict)
390
+ loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict)
391
+ loss_value, metadata_value = self.value_loss(tensordict)
392
+ metadata.update(metadata_qvalue)
393
+ metadata.update(metadata_value)
394
+
395
+ if (loss_actor.shape != loss_qvalue.shape) or (
396
+ loss_value is not None and loss_actor.shape != loss_value.shape
397
+ ):
398
+ raise RuntimeError(
399
+ f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
400
+ )
401
+ tensordict.set(
402
+ self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
403
+ )
404
+ entropy = -tensordict.get(self.tensor_keys.log_prob).detach()
405
+ out = {
406
+ "loss_actor": loss_actor,
407
+ "loss_qvalue": loss_qvalue,
408
+ "loss_value": loss_value,
409
+ "entropy": entropy.mean(),
410
+ }
411
+ td_out = TensorDict(out)
412
+
413
+ self._clear_weakrefs(
414
+ tensordict,
415
+ td_out,
416
+ "actor_network_params",
417
+ "qvalue_network_params",
418
+ "value_network_params",
419
+ "target_actor_network_params",
420
+ "target_qvalue_network_params",
421
+ "target_value_network_params",
422
+ )
423
+ return td_out
424
+
425
+ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
426
+ # KL loss
427
+ with self.actor_network_params.to_module(self.actor_network):
428
+ dist = self.actor_network.get_dist(tensordict)
429
+
430
+ log_prob = dist.log_prob(tensordict[self.tensor_keys.action])
431
+
432
+ # Min Q value
433
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
434
+ td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
435
+ min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
436
+
437
+ if log_prob.shape != min_q.shape:
438
+ raise RuntimeError(
439
+ f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}"
440
+ )
441
+ # state value
442
+ with torch.no_grad():
443
+ td_copy = tensordict.select(
444
+ *self.value_network.in_keys, strict=False
445
+ ).detach()
446
+ with self.value_network_params.to_module(self.value_network):
447
+ self.value_network(td_copy)
448
+ value = td_copy.get(self.tensor_keys.value).squeeze(
449
+ -1
450
+ ) # assert has no gradient
451
+
452
+ exp_a = torch.exp((min_q - value) * self.temperature)
453
+ exp_a = exp_a.clamp_max(100)
454
+
455
+ # write log_prob in tensordict for alpha loss
456
+ tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
457
+ loss_actor = -(exp_a * log_prob)
458
+ loss_actor = _reduce(loss_actor, reduction=self.reduction)
459
+ self._clear_weakrefs(
460
+ tensordict,
461
+ "actor_network_params",
462
+ "qvalue_network_params",
463
+ "value_network_params",
464
+ "target_actor_network_params",
465
+ "target_qvalue_network_params",
466
+ "target_value_network_params",
467
+ )
468
+ return loss_actor, {}
469
+
470
+ def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
471
+ # Min Q value
472
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
473
+ td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
474
+ min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
475
+ # state value
476
+ td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
477
+ with self.value_network_params.to_module(self.value_network):
478
+ self.value_network(td_copy)
479
+ value = td_copy.get(self.tensor_keys.value).squeeze(-1)
480
+ value_loss = self.loss_value_diff(min_q - value, self.expectile)
481
+ value_loss = _reduce(value_loss, reduction=self.reduction)
482
+ self._clear_weakrefs(
483
+ tensordict,
484
+ "actor_network_params",
485
+ "qvalue_network_params",
486
+ "value_network_params",
487
+ "target_actor_network_params",
488
+ "target_qvalue_network_params",
489
+ "target_value_network_params",
490
+ )
491
+ return value_loss, {}
492
+
493
+ def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
494
+ obs_keys = self.actor_network.in_keys
495
+ tensordict = tensordict.select(
496
+ "next", *obs_keys, self.tensor_keys.action, strict=False
497
+ )
498
+
499
+ target_value = self.value_estimator.value_estimate(
500
+ tensordict, target_params=self.target_value_network_params
501
+ ).squeeze(-1)
502
+ tensordict_expand = self._vmap_qvalue_networkN0(
503
+ tensordict.select(*self.qvalue_network.in_keys, strict=False),
504
+ self.qvalue_network_params,
505
+ )
506
+ pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
507
+ -1
508
+ )
509
+ td_error = (pred_val - target_value).pow(2)
510
+ loss_qval = distance_loss(
511
+ pred_val,
512
+ target_value.expand_as(pred_val),
513
+ loss_function=self.loss_function,
514
+ ).sum(0)
515
+ loss_qval = _reduce(loss_qval, reduction=self.reduction)
516
+ metadata = {"td_error": td_error.detach()}
517
+ self._clear_weakrefs(
518
+ tensordict,
519
+ "actor_network_params",
520
+ "qvalue_network_params",
521
+ "value_network_params",
522
+ "target_actor_network_params",
523
+ "target_qvalue_network_params",
524
+ "target_value_network_params",
525
+ )
526
+ return loss_qval, metadata
527
+
528
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
529
+ if value_type is None:
530
+ value_type = self.default_value_estimator
531
+
532
+ # Handle ValueEstimatorBase instance or class
533
+ if isinstance(value_type, ValueEstimatorBase) or (
534
+ isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
535
+ ):
536
+ return LossModule.make_value_estimator(self, value_type, **hyperparams)
537
+
538
+ self.value_type = value_type
539
+ value_net = self.value_network
540
+
541
+ hp = dict(default_value_kwargs(value_type))
542
+ if hasattr(self, "gamma"):
543
+ hp["gamma"] = self.gamma
544
+ hp.update(hyperparams)
545
+ if value_type is ValueEstimators.TD1:
546
+ self._value_estimator = TD1Estimator(
547
+ **hp,
548
+ value_network=value_net,
549
+ )
550
+ elif value_type is ValueEstimators.TD0:
551
+ self._value_estimator = TD0Estimator(
552
+ **hp,
553
+ value_network=value_net,
554
+ )
555
+ elif value_type is ValueEstimators.GAE:
556
+ raise NotImplementedError(
557
+ f"Value type {value_type} it not implemented for loss {type(self)}."
558
+ )
559
+ elif value_type is ValueEstimators.TDLambda:
560
+ self._value_estimator = TDLambdaEstimator(
561
+ **hp,
562
+ value_network=value_net,
563
+ )
564
+ else:
565
+ raise NotImplementedError(f"Unknown value type {value_type}")
566
+
567
+ tensor_keys = {
568
+ "value_target": "value_target",
569
+ "value": self.tensor_keys.value,
570
+ "reward": self.tensor_keys.reward,
571
+ "done": self.tensor_keys.done,
572
+ "terminated": self.tensor_keys.terminated,
573
+ }
574
+ self._value_estimator.set_keys(**tensor_keys)
575
+
576
+
577
+ class DiscreteIQLLoss(IQLLoss):
578
+ r"""TorchRL implementation of the discrete IQL loss.
579
+
580
+ Presented in "Offline Reinforcement Learning with Implicit Q-Learning" https://arxiv.org/abs/2110.06169
581
+
582
+ Args:
583
+ actor_network (ProbabilisticTensorDictSequential): stochastic actor
584
+ qvalue_network (TensorDictModule): Q(s, a) parametric model.
585
+ value_network (TensorDictModule, optional): V(s) parametric model.
586
+
587
+ Keyword Args:
588
+ action_space (str or TensorSpec): Action space. Must be one of
589
+ ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
590
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
591
+ :class:`torchrl.data.MultiOneHot`,
592
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
593
+ num_qvalue_nets (integer, optional): number of Q-Value networks used.
594
+ Defaults to ``2``.
595
+ loss_function (str, optional): loss function to be used with
596
+ the value function loss. Default is `"smooth_l1"`.
597
+ temperature (:obj:`float`, optional): Inverse temperature (beta).
598
+ For smaller hyperparameter values, the objective behaves similarly to
599
+ behavioral cloning, while for larger values, it attempts to recover the
600
+ maximum of the Q-function.
601
+ expectile (:obj:`float`, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial
602
+ for antmaze tasks that require dynamical programming ("stichting").
603
+ priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
604
+ tensordict key where to write the priority (for prioritized replay
605
+ buffer usage). Default is `"td_error"`.
606
+ separate_losses (bool, optional): if ``True``, shared parameters between
607
+ policy and critic will only be trained on the policy loss.
608
+ Defaults to ``False``, i.e., gradients are propagated to shared
609
+ parameters for both policy and critic losses.
610
+ reduction (str, optional): Specifies the reduction to apply to the output:
611
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
612
+ ``"mean"``: the sum of the output will be divided by the number of
613
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
614
+
615
+ Examples:
616
+ >>> import torch
617
+ >>> from torch import nn
618
+ >>> from torchrl.data.tensor_specs import OneHot
619
+ >>> from torchrl.modules.distributions.discrete import OneHotCategorical
620
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
621
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
622
+ >>> from torchrl.objectives.iql import DiscreteIQLLoss
623
+ >>> from tensordict import TensorDict
624
+ >>> n_act, n_obs = 4, 3
625
+ >>> spec = OneHot(n_act)
626
+ >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
627
+ >>> actor = ProbabilisticActor(
628
+ ... module=module,
629
+ ... in_keys=["logits"],
630
+ ... out_keys=["action"],
631
+ ... spec=spec,
632
+ ... distribution_class=OneHotCategorical)
633
+ >>> qvalue = SafeModule(
634
+ ... nn.Linear(n_obs, n_act),
635
+ ... in_keys=["observation"],
636
+ ... out_keys=["state_action_value"],
637
+ ... )
638
+ >>> value = SafeModule(
639
+ ... nn.Linear(n_obs, 1),
640
+ ... in_keys=["observation"],
641
+ ... out_keys=["state_value"],
642
+ ... )
643
+ >>> loss = DiscreteIQLLoss(actor, qvalue, value)
644
+ >>> batch = [2, ]
645
+ >>> action = spec.rand(batch).long()
646
+ >>> data = TensorDict({
647
+ ... "observation": torch.randn(*batch, n_obs),
648
+ ... "action": action,
649
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
650
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
651
+ ... ("next", "reward"): torch.randn(*batch, 1),
652
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
653
+ ... }, batch)
654
+ >>> loss(data)
655
+ TensorDict(
656
+ fields={
657
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
658
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
659
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
660
+ loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
661
+ batch_size=torch.Size([]),
662
+ device=None,
663
+ is_shared=False)
664
+
665
+ This class is compatible with non-tensordict based modules too and can be
666
+ used without recurring to any tensordict-related primitive. In this case,
667
+ the expected keyword arguments are:
668
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network
669
+ The return value is a tuple of tensors in the following order:
670
+ ``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``.
671
+
672
+ Examples:
673
+ >>> import torch
674
+ >>> import torch
675
+ >>> from torch import nn
676
+ >>> from torchrl.data.tensor_specs import OneHot
677
+ >>> from torchrl.modules.distributions.discrete import OneHotCategorical
678
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
679
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
680
+ >>> from torchrl.objectives.iql import DiscreteIQLLoss
681
+ >>> _ = torch.manual_seed(42)
682
+ >>> n_act, n_obs = 4, 3
683
+ >>> spec = OneHot(n_act)
684
+ >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
685
+ >>> actor = ProbabilisticActor(
686
+ ... module=module,
687
+ ... in_keys=["logits"],
688
+ ... out_keys=["action"],
689
+ ... spec=spec,
690
+ ... distribution_class=OneHotCategorical)
691
+ >>> qvalue = SafeModule(
692
+ ... nn.Linear(n_obs, n_act),
693
+ ... in_keys=["observation"],
694
+ ... out_keys=["state_action_value"],
695
+ ... )
696
+ >>> value = SafeModule(
697
+ ... nn.Linear(n_obs, 1),
698
+ ... in_keys=["observation"],
699
+ ... out_keys=["state_value"],
700
+ ... )
701
+ >>> loss = DiscreteIQLLoss(actor, qvalue, value)
702
+ >>> batch = [2, ]
703
+ >>> action = spec.rand(batch).long()
704
+ >>> loss_actor, loss_qvalue, loss_value, entropy = loss(
705
+ ... observation=torch.randn(*batch, n_obs),
706
+ ... action=action,
707
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
708
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
709
+ ... next_observation=torch.zeros(*batch, n_obs),
710
+ ... next_reward=torch.randn(*batch, 1))
711
+ >>> loss_actor.backward()
712
+
713
+
714
+ The output keys can also be filtered using the :meth:`DiscreteIQLLoss.select_out_keys`
715
+ method.
716
+
717
+ Examples:
718
+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
719
+ >>> loss_actor, loss_qvalue, loss_value = loss(
720
+ ... observation=torch.randn(*batch, n_obs),
721
+ ... action=action,
722
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
723
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
724
+ ... next_observation=torch.zeros(*batch, n_obs),
725
+ ... next_reward=torch.randn(*batch, 1))
726
+ >>> loss_actor.backward()
727
+ """
728
+
729
+ @dataclass
730
+ class _AcceptedKeys:
731
+ """Maintains default values for all configurable tensordict keys.
732
+
733
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
734
+ default values
735
+
736
+ Attributes:
737
+ value (NestedKey): The input tensordict key where the state value is expected.
738
+ Will be used for the underlying value estimator. Defaults to ``"state_value"``.
739
+ action (NestedKey): The input tensordict key where the action is expected.
740
+ Defaults to ``"action"``.
741
+ log_prob (NestedKey): The input tensordict key where the log probability is expected.
742
+ Defaults to ``"_log_prob"``.
743
+ priority (NestedKey): The input tensordict key where the target priority is written to.
744
+ Defaults to ``"td_error"``.
745
+ state_action_value (NestedKey): The input tensordict key where the
746
+ state action value is expected. Will be used for the underlying
747
+ value estimator as value key. Defaults to ``"state_action_value"``.
748
+ reward (NestedKey): The input tensordict key where the reward is expected.
749
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
750
+ done (NestedKey): The key in the input TensorDict that indicates
751
+ whether a trajectory is done. Will be used for the underlying value estimator.
752
+ Defaults to ``"done"``.
753
+ terminated (NestedKey): The key in the input TensorDict that indicates
754
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
755
+ Defaults to ``"terminated"``.
756
+ """
757
+
758
+ value: NestedKey = "state_value"
759
+ action: NestedKey = "action"
760
+ log_prob: NestedKey = "_log_prob"
761
+ priority: NestedKey = "td_error"
762
+ state_action_value: NestedKey = "state_action_value"
763
+ reward: NestedKey = "reward"
764
+ done: NestedKey = "done"
765
+ terminated: NestedKey = "terminated"
766
+
767
+ tensor_keys: _AcceptedKeys
768
+ default_keys = _AcceptedKeys
769
+ default_value_estimator = ValueEstimators.TD0
770
+ out_keys = [
771
+ "loss_actor",
772
+ "loss_qvalue",
773
+ "loss_value",
774
+ "entropy",
775
+ ]
776
+
777
+ actor_network: TensorDictModule
778
+ actor_network_params: TensorDictParams
779
+ target_actor_network_params: TensorDictParams
780
+ qvalue_network: TensorDictModule
781
+ qvalue_network_params: TensorDictParams
782
+ target_qvalue_network_params: TensorDictParams
783
+ value_network: TensorDictModule | None
784
+ value_network_params: TensorDictParams | None
785
+ target_value_network_params: TensorDictParams | None
786
+
787
+ def __init__(
788
+ self,
789
+ actor_network: ProbabilisticTensorDictSequential,
790
+ qvalue_network: TensorDictModule,
791
+ value_network: TensorDictModule | None,
792
+ *,
793
+ action_space: str | TensorSpec = None,
794
+ num_qvalue_nets: int = 2,
795
+ loss_function: str = "smooth_l1",
796
+ temperature: float = 1.0,
797
+ expectile: float = 0.5,
798
+ gamma: float | None = None,
799
+ priority_key: str | None = None,
800
+ separate_losses: bool = False,
801
+ reduction: str | None = None,
802
+ ) -> None:
803
+ self._in_keys = None
804
+ self._out_keys = None
805
+ if reduction is None:
806
+ reduction = "mean"
807
+ if expectile >= 1.0:
808
+ raise ValueError(f"Expectile should be lower than 1.0 but is {expectile}")
809
+ super().__init__(
810
+ actor_network=actor_network,
811
+ qvalue_network=qvalue_network,
812
+ value_network=value_network,
813
+ num_qvalue_nets=num_qvalue_nets,
814
+ loss_function=loss_function,
815
+ temperature=temperature,
816
+ expectile=expectile,
817
+ gamma=gamma,
818
+ priority_key=priority_key,
819
+ separate_losses=separate_losses,
820
+ )
821
+ if action_space is None:
822
+ warnings.warn(
823
+ "action_space was not specified. DiscreteIQLLoss will default to 'one-hot'. "
824
+ "This behavior will be deprecated soon and a space will have to be passed. "
825
+ "Check the DiscreteIQLLoss documentation to see how to pass the action space. "
826
+ )
827
+ action_space = "one-hot"
828
+ self.action_space = _find_action_space(action_space)
829
+ self.reduction = reduction
830
+
831
+ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
832
+ # KL loss
833
+ with self.actor_network_params.to_module(self.actor_network):
834
+ dist = self.actor_network.get_dist(tensordict)
835
+
836
+ log_prob = dist.log_prob(tensordict[self.tensor_keys.action])
837
+
838
+ # Min Q value
839
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
840
+ td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
841
+ state_action_value = td_q.get(self.tensor_keys.state_action_value)
842
+ action = tensordict.get(self.tensor_keys.action)
843
+ if self.action_space == "categorical":
844
+ if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
845
+ # unsqueeze the action if it lacks on trailing singleton dim
846
+ action = action.unsqueeze(-1)
847
+ if self.deactivate_vmap:
848
+ vmap = _pseudo_vmap
849
+ else:
850
+ vmap = torch.vmap
851
+ chosen_state_action_value = vmap(
852
+ lambda state_action_value, action: torch.gather(
853
+ state_action_value, -1, index=action
854
+ ).squeeze(-1),
855
+ (0, None),
856
+ )(state_action_value, action)
857
+ elif self.action_space == "one_hot":
858
+ action = action.to(torch.float)
859
+ chosen_state_action_value = (state_action_value * action).sum(-1)
860
+ else:
861
+ raise RuntimeError(f"Unknown action space {self.action_space}.")
862
+ min_Q, _ = torch.min(chosen_state_action_value, dim=0)
863
+ if log_prob.shape != min_Q.shape:
864
+ raise RuntimeError(
865
+ f"Losses shape mismatch: {log_prob.shape} and {min_Q.shape}"
866
+ )
867
+ with torch.no_grad():
868
+ # state value
869
+ td_copy = tensordict.select(
870
+ *self.value_network.in_keys, strict=False
871
+ ).detach()
872
+ with self.value_network_params.to_module(self.value_network):
873
+ self.value_network(td_copy)
874
+ value = td_copy.get(self.tensor_keys.value).squeeze(
875
+ -1
876
+ ) # assert has no gradient
877
+
878
+ exp_a = torch.exp((min_Q - value) * self.temperature)
879
+ exp_a = exp_a.clamp_max(100)
880
+
881
+ # write log_prob in tensordict for alpha loss
882
+ tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
883
+ loss_actor = -(exp_a * log_prob)
884
+ loss_actor = _reduce(loss_actor, reduction=self.reduction)
885
+ self._clear_weakrefs(
886
+ tensordict,
887
+ "actor_network_params",
888
+ "qvalue_network_params",
889
+ "value_network_params",
890
+ "target_actor_network_params",
891
+ "target_qvalue_network_params",
892
+ "target_value_network_params",
893
+ )
894
+ return loss_actor, {}
895
+
896
+ def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
897
+ # Min Q value
898
+ with torch.no_grad():
899
+ # Min Q value
900
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
901
+ td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
902
+ state_action_value = td_q.get(self.tensor_keys.state_action_value)
903
+ action = tensordict.get(self.tensor_keys.action)
904
+ if self.action_space == "categorical":
905
+ if action.ndim < (
906
+ state_action_value.ndim - (td_q.ndim - tensordict.ndim)
907
+ ):
908
+ # unsqueeze the action if it lacks on trailing singleton dim
909
+ action = action.unsqueeze(-1)
910
+ if self.deactivate_vmap:
911
+ vmap = _pseudo_vmap
912
+ else:
913
+ vmap = torch.vmap
914
+ chosen_state_action_value = vmap(
915
+ lambda state_action_value, action: torch.gather(
916
+ state_action_value, -1, index=action
917
+ ).squeeze(-1),
918
+ (0, None),
919
+ )(state_action_value, action)
920
+ elif self.action_space == "one_hot":
921
+ action = action.to(torch.float)
922
+ chosen_state_action_value = (state_action_value * action).sum(-1)
923
+ else:
924
+ raise RuntimeError(f"Unknown action space {self.action_space}.")
925
+ min_Q, _ = torch.min(chosen_state_action_value, dim=0)
926
+ # state value
927
+ td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
928
+ with self.value_network_params.to_module(self.value_network):
929
+ self.value_network(td_copy)
930
+ value = td_copy.get(self.tensor_keys.value).squeeze(-1)
931
+ value_loss = self.loss_value_diff(min_Q - value, self.expectile)
932
+ value_loss = _reduce(value_loss, reduction=self.reduction)
933
+ self._clear_weakrefs(
934
+ tensordict,
935
+ "actor_network_params",
936
+ "qvalue_network_params",
937
+ "value_network_params",
938
+ "target_actor_network_params",
939
+ "target_qvalue_network_params",
940
+ "target_value_network_params",
941
+ )
942
+ return value_loss, {}
943
+
944
+ def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
945
+ obs_keys = self.actor_network.in_keys
946
+ next_td = tensordict.select(
947
+ "next", *obs_keys, self.tensor_keys.action, strict=False
948
+ )
949
+ with torch.no_grad():
950
+ target_value = self.value_estimator.value_estimate(
951
+ next_td, target_params=self.target_value_network_params
952
+ ).squeeze(-1)
953
+
954
+ # predict current Q value
955
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
956
+ td_q = self._vmap_qvalue_networkN0(td_q, self.qvalue_network_params)
957
+ state_action_value = td_q.get(self.tensor_keys.state_action_value)
958
+ action = tensordict.get(self.tensor_keys.action)
959
+ if self.action_space == "categorical":
960
+ if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
961
+ # unsqueeze the action if it lacks on trailing singleton dim
962
+ action = action.unsqueeze(-1)
963
+ if self.deactivate_vmap:
964
+ vmap = _pseudo_vmap
965
+ else:
966
+ vmap = torch.vmap
967
+ pred_val = vmap(
968
+ lambda state_action_value, action: torch.gather(
969
+ state_action_value, -1, index=action
970
+ ).squeeze(-1),
971
+ (0, None),
972
+ )(state_action_value, action)
973
+ elif self.action_space == "one_hot":
974
+ action = action.to(torch.float)
975
+ pred_val = (state_action_value * action).sum(-1)
976
+ else:
977
+ raise RuntimeError(f"Unknown action space {self.action_space}.")
978
+
979
+ td_error = (pred_val - target_value.expand_as(pred_val)).pow(2)
980
+ loss_qval = distance_loss(
981
+ pred_val,
982
+ target_value.expand_as(pred_val),
983
+ loss_function=self.loss_function,
984
+ ).sum(0)
985
+ loss_qval = _reduce(loss_qval, reduction=self.reduction)
986
+ metadata = {"td_error": td_error.detach()}
987
+ self._clear_weakrefs(
988
+ tensordict,
989
+ "actor_network_params",
990
+ "qvalue_network_params",
991
+ "value_network_params",
992
+ "target_actor_network_params",
993
+ "target_qvalue_network_params",
994
+ "target_value_network_params",
995
+ )
996
+ return loss_qval, metadata