torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/td3.py
@@ -0,0 +1,570 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
+ from tensordict.nn import dispatch, TensorDictModule
+ from tensordict.utils import NestedKey
+
+ from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
+ from torchrl.envs.utils import step_mdp
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import (
+     _cache_values,
+     _GAMMA_LMBDA_DEPREC_ERROR,
+     _reduce,
+     _vmap_func,
+     default_value_kwargs,
+     distance_loss,
+     ValueEstimators,
+ )
+ from torchrl.objectives.value import (
+     TD0Estimator,
+     TD1Estimator,
+     TDLambdaEstimator,
+     ValueEstimatorBase,
+ )
+
+
+ class TD3Loss(LossModule):
+     """TD3 Loss module.
+
+     Args:
+         actor_network (TensorDictModule): the actor to be trained
+         qvalue_network (TensorDictModule): a single Q-value network or a list of
+             Q-value networks.
+             If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
+             times. If a list of modules is passed, their
+             parameters will be stacked unless they share the same identity (in which case
+             the original parameter will be expanded).
+
+             .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
+               and all the parameters will be considered as untied.
+
+     Keyword Args:
+         bounds (tuple of float, optional): the bounds of the action space.
+             Exclusive with action_spec. Either this or ``action_spec`` must
+             be provided.
+         action_spec (TensorSpec, optional): the action spec.
+             Exclusive with bounds. Either this or ``bounds`` must be provided.
+         num_qvalue_nets (int, optional): Number of Q-value networks to be
+             trained. Default is ``2``.
+         policy_noise (:obj:`float`, optional): Standard deviation for the target
+             policy action noise. Default is ``0.2``.
+         noise_clip (:obj:`float`, optional): Clipping range value for the sampled
+             target policy action noise. Default is ``0.5``.
+         priority_key (str, optional): Key where to write the priority value
+             for prioritized replay buffers. Default is
+             `"td_error"`.
+         loss_function (str, optional): loss function to be used for the Q-value.
+             Can be one of ``"smooth_l1"``, ``"l2"``,
+             ``"l1"``. Default is ``"smooth_l1"``.
+         delay_actor (bool, optional): whether to separate the target actor
+             networks from the actor networks used for
+             data collection. Default is ``True``.
+         delay_qvalue (bool, optional): Whether to separate the target Q value
+             networks from the Q value networks used
+             for data collection. Default is ``True``.
+         spec (TensorSpec, optional): the action tensor spec. If not provided,
+             it will be retrieved from the actor.
+         separate_losses (bool, optional): if ``True``, shared parameters between
+             policy and critic will only be trained on the policy loss.
+             Defaults to ``False``, i.e., gradients are propagated to shared
+             parameters for both policy and critic losses.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+             ``"mean"``: the sum of the output will be divided by the number of
+             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+         deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+             Defaults to ``False``.
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+         >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
+         >>> from torchrl.modules.tensordict_module.common import SafeModule
+         >>> from torchrl.objectives.td3 import TD3Loss
+         >>> from tensordict import TensorDict
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> module = nn.Linear(n_obs, n_act)
+         >>> actor = Actor(
+         ...     module=module,
+         ...     spec=spec)
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> qvalue = ValueOperator(
+         ...     module=module,
+         ...     in_keys=['observation', 'action'])
+         >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
+         >>> batch = [2, ]
+         >>> action = spec.rand(batch)
+         >>> data = TensorDict({
+         ...     "observation": torch.randn(*batch, n_obs),
+         ...     "action": action,
+         ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "reward"): torch.randn(*batch, 1),
+         ...     ("next", "observation"): torch.randn(*batch, n_obs),
+         ... }, batch)
+         >>> loss(data)
+         TensorDict(
+             fields={
+                 loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=None,
+             is_shared=False)
+
+     This class is compatible with non-tensordict based modules too and can be
+     used without resorting to any tensordict-related primitive. In this case,
+     the expected keyword arguments are:
+     ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
+     The return value is a tuple of tensors in the following order:
+     ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``.
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
+         >>> from torchrl.objectives.td3 import TD3Loss
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> module = nn.Linear(n_obs, n_act)
+         >>> actor = Actor(
+         ...     module=module,
+         ...     spec=spec)
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> qvalue = ValueOperator(
+         ...     module=module,
+         ...     in_keys=['observation', 'action'])
+         >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
+         >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
+         >>> batch = [2, ]
+         >>> action = spec.rand(batch)
+         >>> loss_actor, loss_qvalue = loss(
+         ...     observation=torch.randn(*batch, n_obs),
+         ...     action=action,
+         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_reward=torch.randn(*batch, 1),
+         ...     next_observation=torch.randn(*batch, n_obs))
+         >>> loss_actor.backward()
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             action (NestedKey): The input tensordict key where the action is expected.
+                 Defaults to ``"action"``.
+             state_action_value (NestedKey): The input tensordict key where the state action value is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
+             priority (NestedKey): The input tensordict key where the target priority is written to.
+                 Defaults to ``"td_error"``.
+             reward (NestedKey): The input tensordict key where the reward is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"reward"``.
+             done (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is done. Will be used for the underlying value estimator.
+                 Defaults to ``"done"``.
+             terminated (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is terminated. Will be used for the underlying value estimator.
+                 Defaults to ``"terminated"``.
+         """
+
+         action: NestedKey = "action"
+         state_action_value: NestedKey = "state_action_value"
+         priority: NestedKey = "td_error"
+         reward: NestedKey = "reward"
+         done: NestedKey = "done"
+         terminated: NestedKey = "terminated"
+         priority_weight: NestedKey = "priority_weight"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+     default_value_estimator = ValueEstimators.TD0
+     out_keys = [
+         "loss_actor",
+         "loss_qvalue",
+         "pred_value",
+         "state_action_value_actor",
+         "next_state_value",
+         "target_value",
+     ]
+
+     actor_network: TensorDictModule
+     qvalue_network: TensorDictModule
+     actor_network_params: TensorDictParams
+     qvalue_network_params: TensorDictParams
+     target_actor_network_params: TensorDictParams
+     target_qvalue_network_params: TensorDictParams
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule,
+         qvalue_network: TensorDictModule | list[TensorDictModule],
+         *,
+         action_spec: TensorSpec = None,
+         bounds: tuple[float] | None = None,
+         num_qvalue_nets: int = 2,
+         policy_noise: float = 0.2,
+         noise_clip: float = 0.5,
+         loss_function: str = "smooth_l1",
+         delay_actor: bool = True,
+         delay_qvalue: bool = True,
+         gamma: float | None = None,
+         priority_key: str | None = None,
+         separate_losses: bool = False,
+         reduction: str | None = None,
+         deactivate_vmap: bool = False,
+         use_prioritized_weights: str | bool = "auto",
+     ) -> None:
+         if reduction is None:
+             reduction = "mean"
+         super().__init__()
+         self.use_prioritized_weights = use_prioritized_weights
+         self._in_keys = None
+         self._set_deprecated_ctor_keys(priority=priority_key)
+
+         self.delay_actor = delay_actor
+         self.delay_qvalue = delay_qvalue
+         self.deactivate_vmap = deactivate_vmap
+
+         self.convert_to_functional(
+             actor_network,
+             "actor_network",
+             create_target_params=self.delay_actor,
+         )
+         if separate_losses:
+             # we want to make sure there are no duplicates in the params: the
+             # params of critic must be refs to actor if they're shared
+             policy_params = list(actor_network.parameters())
+         else:
+             policy_params = None
+         self.convert_to_functional(
+             qvalue_network,
+             "qvalue_network",
+             num_qvalue_nets,
+             create_target_params=self.delay_qvalue,
+             compare_against=policy_params,
+         )
+
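+         # grab the device from the first parameter; the for/else falls through
+         # to ``None`` when the module holds no parameters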
+         for p in self.parameters():
+             device = p.device
+             break
+         else:
+             device = None
+         self.num_qvalue_nets = num_qvalue_nets
+         self.loss_function = loss_function
+         self.policy_noise = policy_noise
+         self.noise_clip = noise_clip
+         if not ((action_spec is not None) ^ (bounds is not None)):
+             raise ValueError(
+                 "One of 'bounds' and 'action_spec' must be provided, "
+                 f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
+             )
+         elif action_spec is not None:
+             if isinstance(action_spec, Composite):
+                 if (
+                     isinstance(self.tensor_keys.action, tuple)
+                     and len(self.tensor_keys.action) > 1
+                 ):
+                     action_container_shape = action_spec[
+                         self.tensor_keys.action[:-1]
+                     ].shape
+                 else:
+                     action_container_shape = action_spec.shape
+                 action_spec = action_spec[self.tensor_keys.action][
+                     (0,) * len(action_container_shape)
+                 ]
+             if not isinstance(action_spec, Bounded):
+                 raise ValueError(
+                     f"action_spec is not of type Bounded but {type(action_spec)}."
+                 )
+             low = action_spec.space.low
+             high = action_spec.space.high
+         else:
+             low, high = bounds
+         if not isinstance(low, torch.Tensor):
+             low = torch.tensor(low)
+         if not isinstance(high, torch.Tensor):
+             high = torch.tensor(high, device=low.device, dtype=low.dtype)
+         if (low > high).any():
+             raise ValueError("Got a low bound higher than a high bound.")
+         if device is not None:
+             low = low.to(device)
+             high = high.to(device)
+         self.register_buffer("max_action", high)
+         self.register_buffer("min_action", low)
+         if gamma is not None:
+             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+         self._make_vmap()
+         self.reduction = reduction
+
+     def _make_vmap(self):
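+         # build vmapped (batched) calls over the stacked functional parameters
+         # of the Q-networks and the actor; with deactivate_vmap=True, pseudo_vmap
+         # replaces vmap with a plain for loop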
+         self._vmap_qvalue_network00 = _vmap_func(
+             self.qvalue_network,
+             randomness=self.vmap_randomness,
+             pseudo_vmap=self.deactivate_vmap,
+         )
+         self._vmap_actor_network00 = _vmap_func(
+             self.actor_network,
+             randomness=self.vmap_randomness,
+             pseudo_vmap=self.deactivate_vmap,
+         )
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         if self._value_estimator is not None:
+             self._value_estimator.set_keys(
+                 value=self._tensor_keys.state_action_value,
+                 reward=self.tensor_keys.reward,
+                 done=self.tensor_keys.done,
+                 terminated=self.tensor_keys.terminated,
+             )
+         self._set_in_keys()
+
+     def _set_in_keys(self):
+         keys = [
+             self.tensor_keys.action,
+             ("next", self.tensor_keys.reward),
+             ("next", self.tensor_keys.done),
+             ("next", self.tensor_keys.terminated),
+             *self.actor_network.in_keys,
+             *[("next", key) for key in self.actor_network.in_keys],
+             *self.qvalue_network.in_keys,
+         ]
+         self._in_keys = list(set(keys))
+
+     @property
+     def in_keys(self):
+         if self._in_keys is None:
+             self._set_in_keys()
+         return self._in_keys
+
+     @in_keys.setter
+     def in_keys(self, values):
+         self._in_keys = values
+
+     @property
+     @_cache_values
+     def _cached_detach_qvalue_network_params(self):
+         return self.qvalue_network_params.detach()
+
+     @property
+     @_cache_values
+     def _cached_stack_actor_params(self):
+         return torch.stack(
+             [self.actor_network_params, self.target_actor_network_params], 0
+         )
+
+     def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
+         weights = self._maybe_get_priority_weight(tensordict)
+         tensordict_actor_grad = tensordict.select(
+             *self.actor_network.in_keys, strict=False
+         )
+         with self.actor_network_params.to_module(self.actor_network):
+             tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
+         actor_loss_td = tensordict_actor_grad.select(
+             *self.qvalue_network.in_keys, strict=False
+         ).expand(
+             self.num_qvalue_nets, *tensordict_actor_grad.batch_size
+         )  # for actor loss
+         state_action_value_actor = (
+             self._vmap_qvalue_network00(
+                 actor_loss_td,
+                 self._cached_detach_qvalue_network_params,
+             )
+             .get(self.tensor_keys.state_action_value)
+             .squeeze(-1)
+         )
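+         # the policy is updated by maximizing the first Q-network only
+         # (hence the [0] index), as in the original TD3 algorithm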
+         loss_actor = -(state_action_value_actor[0])
+         metadata = {
+             "state_action_value_actor": state_action_value_actor.detach(),
+         }
+         loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights)
+         self._clear_weakrefs(
+             tensordict,
+             "actor_network_params",
+             "qvalue_network_params",
+             "target_actor_network_params",
+             "target_qvalue_network_params",
+         )
+         return loss_actor, metadata
+
+     def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
+         weights = self._maybe_get_priority_weight(tensordict)
+         tensordict = tensordict.clone(False)
+
+         act = tensordict.get(self.tensor_keys.action)
+
+         # compute the noise early for reproducibility
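+         # target policy smoothing: clipped Gaussian noise is added below to the
+         # target policy's action before evaluating the target Q-values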
+         noise = (torch.randn_like(act) * self.policy_noise).clamp(
+             -self.noise_clip, self.noise_clip
+         )
+
+         with torch.no_grad():
+             next_td_actor = step_mdp(tensordict).select(
+                 *self.actor_network.in_keys, strict=False
+             )  # next_observation ->
+             with self.target_actor_network_params.to_module(self.actor_network):
+                 next_td_actor = self.actor_network(next_td_actor)
+             next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
+                 self.min_action, self.max_action
+             )
+             next_td_actor.set(
+                 self.tensor_keys.action,
+                 next_action,
+             )
+             next_val_td = next_td_actor.select(
+                 *self.qvalue_network.in_keys, strict=False
+             ).expand(
+                 self.num_qvalue_nets, *next_td_actor.batch_size
+             )  # for next value estimation
+             next_target_q1q2 = (
+                 self._vmap_qvalue_network00(
+                     next_val_td,
+                     self.target_qvalue_network_params,
+                 )
+                 .get(self.tensor_keys.state_action_value)
+                 .squeeze(-1)
+             )
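+         # clipped double-Q learning: the elementwise minimum over the target
+         # Q-networks curbs overestimation of the bootstrapped target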
+         # min over the next target qvalues
+         next_target_qvalue = next_target_q1q2.min(0)[0]
+
+         # set next target qvalues
+         tensordict.set(
+             ("next", self.tensor_keys.state_action_value),
+             next_target_qvalue.unsqueeze(-1),
+         )
+
+         qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
+             self.num_qvalue_nets,
+             *tensordict.batch_size,
+         )
+         # predicted current qvalues
+         current_qvalue = (
+             self._vmap_qvalue_network00(
+                 qval_td,
+                 self.qvalue_network_params,
+             )
+             .get(self.tensor_keys.state_action_value)
+             .squeeze(-1)
+         )
+
+         # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
+         target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
+
+         td_error = (current_qvalue - target_value).pow(2)
+         loss_qval = distance_loss(
+             current_qvalue,
+             target_value.expand_as(current_qvalue),
+             loss_function=self.loss_function,
+         ).sum(0)
+         metadata = {
+             "td_error": td_error,
+             "next_state_value": next_target_qvalue.detach(),
+             "pred_value": current_qvalue.detach(),
+             "target_value": target_value.detach(),
+         }
+         loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights)
+         self._clear_weakrefs(
+             tensordict,
+             "actor_network_params",
+             "qvalue_network_params",
+             "target_actor_network_params",
+             "target_qvalue_network_params",
+         )
+         return loss_qval, metadata
+
+     @dispatch
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         tensordict_save = tensordict
+         loss_actor, metadata_actor = self.actor_loss(tensordict)
+         loss_qval, metadata_value = self.value_loss(tensordict_save)
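+         # write the per-sample priority (max TD error across the Q-network
+         # ensemble) back into the input tensordict for prioritized replay buffers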
+         tensordict_save.set(
+             self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
+         )
+         if not loss_qval.shape == loss_actor.shape:
+             raise RuntimeError(
+                 f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
+             )
+         td_out = TensorDict(
+             loss_actor=loss_actor,
+             loss_qvalue=loss_qval,
+             **metadata_actor,
+             **metadata_value,
+         )
+         self._clear_weakrefs(
+             tensordict,
+             "actor_network_params",
+             "qvalue_network_params",
+             "target_actor_network_params",
+             "target_qvalue_network_params",
+         )
+         return td_out
+
+     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+         if value_type is None:
+             value_type = self.default_value_estimator
+
+         # Handle ValueEstimatorBase instance or class
+         if isinstance(value_type, ValueEstimatorBase) or (
+             isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+         ):
+             return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+         self.value_type = value_type
+         hp = dict(default_value_kwargs(value_type))
+         if hasattr(self, "gamma"):
+             hp["gamma"] = self.gamma
+         hp.update(hyperparams)
+         # we do not need a value network bc the next state value is already passed
+         if value_type == ValueEstimators.TD1:
+             self._value_estimator = TD1Estimator(value_network=None, **hp)
+         elif value_type == ValueEstimators.TD0:
+             self._value_estimator = TD0Estimator(value_network=None, **hp)
+         elif value_type == ValueEstimators.GAE:
+             raise NotImplementedError(
+ f"Value type {value_type} it not implemented for loss {type(self)}."
558
+             )
+         elif value_type == ValueEstimators.TDLambda:
+             self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
+         else:
+             raise NotImplementedError(f"Unknown value type {value_type}")
+
+         tensor_keys = {
+             "value": self.tensor_keys.state_action_value,
+             "reward": self.tensor_keys.reward,
+             "done": self.tensor_keys.done,
+             "terminated": self.tensor_keys.terminated,
+         }
+         self._value_estimator.set_keys(**tensor_keys)