torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
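The same file-level view can be reproduced locally. The sketch below is a minimal, illustrative example and not the registry's own tooling: it downloads the published wheel with pip and lists its contents, assuming network access to the package index. Note that pip selects the wheel matching the local platform, which may differ from the macOS arm64 build shown here.

# Minimal sketch: fetch the released wheel and list its files locally.
# Assumes `pip` is on PATH and the default package index is reachable.
import subprocess
import zipfile
from pathlib import Path

dest = Path("wheels")
dest.mkdir(exist_ok=True)

# Download only the torchrl 0.11.0 wheel, without resolving dependencies.
subprocess.run(
    ["pip", "download", "torchrl==0.11.0", "--no-deps", "--dest", str(dest)],
    check=True,
)

# A wheel is a zip archive; its member names correspond to the file list below.
wheel_path = next(dest.glob("torchrl-0.11.0-*.whl"))
with zipfile.ZipFile(wheel_path) as whl:
    for name in sorted(whl.namelist()):
        print(name)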
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
--- /dev/null
+++ b/torchrl/objectives/td3_bc.py
@@ -0,0 +1,625 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+
+from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
+from torchrl.envs.utils import step_mdp
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import (
+    _cache_values,
+    _reduce,
+    _vmap_func,
+    default_value_kwargs,
+    distance_loss,
+    ValueEstimators,
+)
+from torchrl.objectives.value import (
+    TD0Estimator,
+    TD1Estimator,
+    TDLambdaEstimator,
+    ValueEstimatorBase,
+)
+
+
+class TD3BCLoss(LossModule):
+    r"""TD3+BC Loss Module.
+
+    Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to
+    Offline Reinforcement Learning" <https://arxiv.org/pdf/2106.06860>`.
+
+    This class incorporates two loss functions, executed sequentially within the `forward` method:
+
+    1. :meth:`~.qvalue_loss`
+    2. :meth:`~.actor_loss`
+
+    Users also have the option to call these functions directly in the same order if preferred.
+
+    Args:
+        actor_network (TensorDictModule): the actor to be trained
+        qvalue_network (TensorDictModule): a single Q-value network or a list of
+            Q-value networks.
+            If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
+            times. If a list of modules is passed, their
+            parameters will be stacked unless they share the same identity (in which case
+            the original parameter will be expanded).
+
+            .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
+              and all the parameters will be considered as untied.
+
+    Keyword Args:
+        bounds (tuple of float, optional): the bounds of the action space.
+            Exclusive with ``action_spec``. Either this or ``action_spec`` must
+            be provided.
+        action_spec (TensorSpec, optional): the action spec.
+            Exclusive with ``bounds``. Either this or ``bounds`` must be provided.
+        num_qvalue_nets (int, optional): Number of Q-value networks to be
+            trained. Default is ``2``.
+        policy_noise (:obj:`float`, optional): Standard deviation for the target
+            policy action noise. Default is ``0.2``.
+        noise_clip (:obj:`float`, optional): Clipping range value for the sampled
+            target policy action noise. Default is ``0.5``.
+        alpha (:obj:`float`, optional): Weight for the behavioral cloning loss.
+            Defaults to ``2.5``.
+        priority_key (str, optional): Key where to write the priority value
+            for prioritized replay buffers. Default is
+            `"td_error"`.
+        loss_function (str, optional): loss function to be used for the Q-value.
+            Can be one of ``"smooth_l1"``, ``"l2"``,
+            ``"l1"``. Default is ``"smooth_l1"``.
+        delay_actor (bool, optional): whether to separate the target actor
+            networks from the actor networks used for
+            data collection. Default is ``True``.
+        delay_qvalue (bool, optional): Whether to separate the target Q value
+            networks from the Q value networks used
+            for data collection. Default is ``True``.
+        spec (TensorSpec, optional): the action tensor spec. If not provided
+            and the target entropy is ``"auto"``, it will be retrieved from
+            the actor.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+            ``"mean"``: the sum of the output will be divided by the number of
+            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from torchrl.data import Bounded
+        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+        >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
+        >>> from torchrl.modules.tensordict_module.common import SafeModule
+        >>> from torchrl.objectives.td3_bc import TD3BCLoss
+        >>> from tensordict import TensorDict
+        >>> n_act, n_obs = 4, 3
+        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+        >>> module = nn.Linear(n_obs, n_act)
+        >>> actor = Actor(
+        ...     module=module,
+        ...     spec=spec)
+        >>> class ValueClass(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.linear = nn.Linear(n_obs + n_act, 1)
+        ...     def forward(self, obs, act):
+        ...         return self.linear(torch.cat([obs, act], -1))
+        >>> module = ValueClass()
+        >>> qvalue = ValueOperator(
+        ...     module=module,
+        ...     in_keys=['observation', 'action'])
+        >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
+        >>> batch = [2, ]
+        >>> action = spec.rand(batch)
+        >>> data = TensorDict({
+        ...     "observation": torch.randn(*batch, n_obs),
+        ...     "action": action,
+        ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     ("next", "reward"): torch.randn(*batch, 1),
+        ...     ("next", "observation"): torch.randn(*batch, n_obs),
+        ... }, batch)
+        >>> loss(data)
+        TensorDict(
+            fields={
+                bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+
+    This class is compatible with non-tensordict based modules too and can be
+    used without resorting to any tensordict-related primitive. In this case,
+    the expected keyword arguments are:
+    ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
+    The return value is a tuple of tensors in the following order:
+    ``["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"]``.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from torchrl.data import Bounded
+        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
+        >>> from torchrl.objectives.td3_bc import TD3BCLoss
+        >>> n_act, n_obs = 4, 3
+        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+        >>> module = nn.Linear(n_obs, n_act)
+        >>> actor = Actor(
+        ...     module=module,
+        ...     spec=spec)
+        >>> class ValueClass(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.linear = nn.Linear(n_obs + n_act, 1)
+        ...     def forward(self, obs, act):
+        ...         return self.linear(torch.cat([obs, act], -1))
+        >>> module = ValueClass()
+        >>> qvalue = ValueOperator(
+        ...     module=module,
+        ...     in_keys=['observation', 'action'])
+        >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
+        >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
+        >>> batch = [2, ]
+        >>> action = spec.rand(batch)
+        >>> loss_actor, loss_qvalue = loss(
+        ...     observation=torch.randn(*batch, n_obs),
+        ...     action=action,
+        ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+        ...     next_reward=torch.randn(*batch, 1),
+        ...     next_observation=torch.randn(*batch, n_obs))
+        >>> loss_actor.backward()
+
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            state_action_value (NestedKey): The input tensordict key where the state action value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
+            priority (NestedKey): The input tensordict key where the target priority is written to.
+                Defaults to ``"td_error"``.
+            reward (NestedKey): The input tensordict key where the reward is expected.
+                Will be used for the underlying value estimator. Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Will be used for the underlying value estimator.
+                Defaults to ``"done"``.
+            terminated (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is terminated. Will be used for the underlying value estimator.
+                Defaults to ``"terminated"``.
+        """
+
+        action: NestedKey = "action"
+        state_action_value: NestedKey = "state_action_value"
+        priority: NestedKey = "td_error"
+        reward: NestedKey = "reward"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+        priority_weight: NestedKey = "priority_weight"
+
+    tensor_keys: _AcceptedKeys
+    default_keys = _AcceptedKeys
+    default_value_estimator = ValueEstimators.TD0
+    out_keys = [
+        "loss_actor",
+        "loss_qvalue",
+        "bc_loss",
+        "lmbd",
+        "pred_value",
+        "state_action_value_actor",
+        "next_state_value",
+        "target_value",
+    ]
+
+    actor_network: TensorDictModule
+    qvalue_network: TensorDictModule
+    actor_network_params: TensorDictParams
+    qvalue_network_params: TensorDictParams
+    target_actor_network_params: TensorDictParams
+    target_qvalue_network_params: TensorDictParams
+
+    def __init__(
+        self,
+        actor_network: TensorDictModule,
+        qvalue_network: TensorDictModule | list[TensorDictModule],
+        *,
+        action_spec: TensorSpec = None,
+        bounds: tuple[float] | None = None,
+        num_qvalue_nets: int = 2,
+        policy_noise: float = 0.2,
+        noise_clip: float = 0.5,
+        alpha: float = 2.5,
+        loss_function: str = "smooth_l1",
+        delay_actor: bool = True,
+        delay_qvalue: bool = True,
+        priority_key: str | None = None,
+        separate_losses: bool = False,
+        reduction: str | None = None,
+        deactivate_vmap: bool = False,
+        use_prioritized_weights: str | bool = "auto",
+    ) -> None:
+        if reduction is None:
+            reduction = "mean"
+        super().__init__()
+        self.use_prioritized_weights = use_prioritized_weights
+        self._in_keys = None
+        self._set_deprecated_ctor_keys(priority=priority_key)
+
+        self.delay_actor = delay_actor
+        self.delay_qvalue = delay_qvalue
+        self.deactivate_vmap = deactivate_vmap
+
+        self.convert_to_functional(
+            actor_network,
+            "actor_network",
+            create_target_params=self.delay_actor,
+        )
+        if separate_losses:
+            # we want to make sure there are no duplicates in the params: the
+            # params of critic must be refs to actor if they're shared
+            policy_params = list(actor_network.parameters())
+        else:
+            policy_params = None
+        self.convert_to_functional(
+            qvalue_network,
+            "qvalue_network",
+            num_qvalue_nets,
+            create_target_params=self.delay_qvalue,
+            compare_against=policy_params,
+        )
+
+        for p in self.parameters():
+            device = p.device
+            break
+        else:
+            device = None
+        self.num_qvalue_nets = num_qvalue_nets
+        self.loss_function = loss_function
+        self.policy_noise = policy_noise
+        self.noise_clip = noise_clip
+        self.alpha = alpha
+        if not ((action_spec is not None) ^ (bounds is not None)):
+            raise ValueError(
+                "One of 'bounds' and 'action_spec' must be provided, "
+                f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
+            )
+        elif action_spec is not None:
+            if isinstance(action_spec, Composite):
+                if (
+                    isinstance(self.tensor_keys.action, tuple)
+                    and len(self.tensor_keys.action) > 1
+                ):
+                    action_container_shape = action_spec[
+                        self.tensor_keys.action[:-1]
+                    ].shape
+                else:
+                    action_container_shape = action_spec.shape
+                action_spec = action_spec[self.tensor_keys.action][
+                    (0,) * len(action_container_shape)
+                ]
+            if not isinstance(action_spec, Bounded):
+                raise ValueError(
+                    f"action_spec is not of type Bounded but {type(action_spec)}."
+                )
+            low = action_spec.space.low
+            high = action_spec.space.high
+        else:
+            low, high = bounds
+        if not isinstance(low, torch.Tensor):
+            low = torch.tensor(low)
+        if not isinstance(high, torch.Tensor):
+            high = torch.tensor(high, device=low.device, dtype=low.dtype)
+        if (low > high).any():
+            raise ValueError("Got a low bound higher than a high bound.")
+        if device is not None:
+            low = low.to(device)
+            high = high.to(device)
+        self.register_buffer("max_action", high)
+        self.register_buffer("min_action", low)
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
+        self._vmap_qvalue_network00 = _vmap_func(
+            self.qvalue_network,
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
+        )
+        self._vmap_actor_network00 = _vmap_func(
+            self.actor_network,
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
+        )
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                value=self._tensor_keys.state_action_value,
+                reward=self.tensor_keys.reward,
+                done=self.tensor_keys.done,
+                terminated=self.tensor_keys.terminated,
+            )
+        self._set_in_keys()
+
+    def _set_in_keys(self):
+        keys = [
+            self.tensor_keys.action,
+            ("next", self.tensor_keys.reward),
+            ("next", self.tensor_keys.done),
+            ("next", self.tensor_keys.terminated),
+            *self.actor_network.in_keys,
+            *[("next", key) for key in self.actor_network.in_keys],
+            *self.qvalue_network.in_keys,
+        ]
+        self._in_keys = list(set(keys))
+
+    @property
+    def in_keys(self):
+        if self._in_keys is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    @_cache_values
+    def _cached_detach_qvalue_network_params(self):
+        return self.qvalue_network_params.detach()
+
+    @property
+    @_cache_values
+    def _cached_stack_actor_params(self):
+        return torch.stack(
+            [self.actor_network_params, self.target_actor_network_params], 0
+        )
+
+    def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
+        """Compute the actor loss.
+
+        The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed by 1-3 critic updates.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+        Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"`
+            used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda
+            value, and the lambda value `"lmbd"` itself.
+        """
+        weights = self._maybe_get_priority_weight(tensordict)
+        tensordict_actor_grad = tensordict.select(
+            *self.actor_network.in_keys, strict=False
+        )
+        with self.actor_network_params.to_module(self.actor_network):
+            tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
+        actor_loss_td = tensordict_actor_grad.select(
+            *self.qvalue_network.in_keys, strict=False
+        ).expand(
+            self.num_qvalue_nets, *tensordict_actor_grad.batch_size
+        )  # for actor loss
+        state_action_value_actor = (
+            self._vmap_qvalue_network00(
+                actor_loss_td,
+                self._cached_detach_qvalue_network_params,
+            )
+            .get(self.tensor_keys.state_action_value)
+            .squeeze(-1)
+        )
+
+        bc_loss = torch.nn.functional.mse_loss(
+            tensordict_actor_grad.get(self.tensor_keys.action),
+            tensordict.get(self.tensor_keys.action),
+        )
+        lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach()
+
+        loss_actor = -lmbd * state_action_value_actor[0] + bc_loss
+
+        metadata = {
+            "state_action_value_actor": state_action_value_actor[0].detach(),
+            "bc_loss": bc_loss.detach(),
+            "lmbd": lmbd,
+        }
+        loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights)
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
+        return loss_actor, metadata
+
+    def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
+        """Compute the q-value loss.
+
+        The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+        Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
+            the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`.
+        """
+        weights = self._maybe_get_priority_weight(tensordict)
+        tensordict = tensordict.clone(False)
+
+        act = tensordict.get(self.tensor_keys.action)
+
+        # compute the noise early for reproducibility
+        noise = (torch.randn_like(act) * self.policy_noise).clamp(
+            -self.noise_clip, self.noise_clip
+        )
+
+        with torch.no_grad():
+            next_td_actor = step_mdp(tensordict).select(
+                *self.actor_network.in_keys, strict=False
+            )  # next_observation ->
+            with self.target_actor_network_params.to_module(self.actor_network):
+                next_td_actor = self.actor_network(next_td_actor)
+            next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
+                self.min_action, self.max_action
+            )
+            next_td_actor.set(
+                self.tensor_keys.action,
+                next_action,
+            )
+            next_val_td = next_td_actor.select(
+                *self.qvalue_network.in_keys, strict=False
+            ).expand(
+                self.num_qvalue_nets, *next_td_actor.batch_size
+            )  # for next value estimation
+            next_target_q1q2 = (
+                self._vmap_qvalue_network00(
+                    next_val_td,
+                    self.target_qvalue_network_params,
+                )
+                .get(self.tensor_keys.state_action_value)
+                .squeeze(-1)
+            )
+        # min over the next target qvalues
+        next_target_qvalue = next_target_q1q2.min(0)[0]
+
+        # set next target qvalues
+        tensordict.set(
+            ("next", self.tensor_keys.state_action_value),
+            next_target_qvalue.unsqueeze(-1),
+        )
+
+        qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
+            self.num_qvalue_nets,
+            *tensordict.batch_size,
+        )
+        # predicted current qvalues
+        current_qvalue = (
+            self._vmap_qvalue_network00(
+                qval_td,
+                self.qvalue_network_params,
+            )
+            .get(self.tensor_keys.state_action_value)
+            .squeeze(-1)
+        )
+
+        # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
+
+        td_error = (current_qvalue - target_value).pow(2)
+        loss_qval = distance_loss(
+            current_qvalue,
+            target_value.expand_as(current_qvalue),
+            loss_function=self.loss_function,
+        ).sum(0)
+        metadata = {
+            "td_error": td_error,
+            "next_state_value": next_target_qvalue.detach(),
+            "pred_value": current_qvalue.detach(),
+            "target_value": target_value.detach(),
+        }
+        loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights)
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
+        return loss_qval, metadata
+
+    @dispatch
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """The forward method.
+
+        Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns
+        a tensordict with these values.
+        To see what keys are expected in the input tensordict and what keys are expected as output, check the
+        class's `"in_keys"` and `"out_keys"` attributes.
+        """
+        tensordict_save = tensordict
+        loss_actor, metadata_actor = self.actor_loss(tensordict)
+        loss_qval, metadata_value = self.qvalue_loss(tensordict_save)
+        tensordict_save.set(
+            self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
+        )
+        if not loss_qval.shape == loss_actor.shape:
+            raise RuntimeError(
+                f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
+            )
+        td_out = TensorDict(
+            source={
+                "loss_actor": loss_actor,
+                "loss_qvalue": loss_qval,
+                **metadata_actor,
+                **metadata_value,
+            },
+        )
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
+        return td_out
+
+    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        if value_type is None:
+            value_type = self.default_value_estimator
+
+        # Handle ValueEstimatorBase instance or class
+        if isinstance(value_type, ValueEstimatorBase) or (
+            isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+        ):
+            return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+        self.value_type = value_type
+        hp = dict(default_value_kwargs(value_type))
+        if hasattr(self, "gamma"):
+            hp["gamma"] = self.gamma
+        hp.update(hyperparams)
+        # we do not need a value network because the next state value is already passed
+        if value_type == ValueEstimators.TD1:
+            self._value_estimator = TD1Estimator(value_network=None, **hp)
+        elif value_type == ValueEstimators.TD0:
+            self._value_estimator = TD0Estimator(value_network=None, **hp)
+        elif value_type == ValueEstimators.GAE:
+            raise NotImplementedError(
+                f"Value type {value_type} is not implemented for loss {type(self)}."
+            )
+        elif value_type == ValueEstimators.TDLambda:
+            self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
+        else:
+            raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "value": self.tensor_keys.state_action_value,
+            "reward": self.tensor_keys.reward,
+            "done": self.tensor_keys.done,
+            "terminated": self.tensor_keys.terminated,
+        }
+        self._value_estimator.set_keys(**tensor_keys)
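
To put the new module in context, here is a minimal, hedged sketch of how TD3BCLoss might be driven in an offline training loop. The toy actor and critic and the random batch mirror the docstring example above; the single Adam optimizer, the SoftUpdate target updater, the gamma value, and the learning rate are illustrative assumptions rather than settings taken from this file.

# Illustrative sketch only: toy networks and random data stand in for a real
# offline dataset; hyperparameters below are assumptions, not torchrl defaults.
import torch
from tensordict import TensorDict
from torch import nn
from torchrl.data import Bounded
from torchrl.modules import Actor, ValueOperator
from torchrl.objectives import SoftUpdate
from torchrl.objectives.td3_bc import TD3BCLoss

n_act, n_obs, batch = 4, 3, 256
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(module=nn.Linear(n_obs, n_act), spec=spec)

class QValue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)
    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])

loss_module = TD3BCLoss(actor, qvalue, action_spec=spec, alpha=2.5)
loss_module.make_value_estimator(gamma=0.99)  # TD0 by default
target_updater = SoftUpdate(loss_module, eps=0.995)
optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

for _ in range(5):  # stand-in for iterating over an offline replay buffer
    data = TensorDict(
        {
            "observation": torch.randn(batch, n_obs),
            "action": spec.rand((batch,)),
            ("next", "observation"): torch.randn(batch, n_obs),
            ("next", "reward"): torch.randn(batch, 1),
            ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
            ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
        },
        batch_size=[batch],
    )
    td_out = loss_module(data)
    loss = td_out["loss_actor"] + td_out["loss_qvalue"]
    optim.zero_grad()
    loss.backward()
    optim.step()
    target_updater.step()  # polyak-average the target networks

In practice the random batch would be replaced by samples from an offline dataset or replay buffer, and the actor and critic are often updated with separate optimizers, with the actor step delayed relative to the critic as the actor_loss docstring above notes.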