torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,683 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
12
+ from tensordict.nn import dispatch, TensorDictModule
13
+ from tensordict.utils import NestedKey
14
+ from torch import nn
15
+
16
+ from torchrl.data.tensor_specs import TensorSpec
17
+ from torchrl.data.utils import _find_action_space
18
+ from torchrl.envs.utils import step_mdp
19
+ from torchrl.modules.tensordict_module.actors import (
20
+ DistributionalQValueActor,
21
+ QValueActor,
22
+ )
23
+ from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
24
+ from torchrl.objectives.common import LossModule
25
+ from torchrl.objectives.utils import (
26
+ _GAMMA_LMBDA_DEPREC_ERROR,
27
+ _reduce,
28
+ default_value_kwargs,
29
+ distance_loss,
30
+ ValueEstimators,
31
+ )
32
+ from torchrl.objectives.value import TDLambdaEstimator, ValueEstimatorBase
33
+ from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator
34
+
35
+
36
+ class DQNLoss(LossModule):
37
+ """The DQN Loss class.
38
+
39
+ Args:
40
+ value_network (QValueActor or nn.Module): a Q value operator.
41
+
42
+ Keyword Args:
43
+ loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
44
+ Defaults to "l2".
45
+ delay_value (bool, optional): whether to duplicate the value network
46
+ into a new target value network to
47
+ create a DQN with a target network. Default is ``True``.
48
+ double_dqn (bool, optional): whether to use Double DQN, as described in
49
+ https://arxiv.org/abs/1509.06461. Defaults to ``False``.
50
+ action_space (str or TensorSpec, optional): Action space. Must be one of
51
+ ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
52
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
53
+ :class:`torchrl.data.MultiOneHot`,
54
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
55
+ If not provided, an attempt to retrieve it from the value network
56
+ will be made.
57
+ priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
58
+ The key at which priority is assumed to be stored within TensorDicts added
59
+ to this ReplayBuffer. This is to be used when the sampler is of type
60
+ :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``.
61
+ reduction (str, optional): Specifies the reduction to apply to the output:
62
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
63
+ ``"mean"``: the sum of the output will be divided by the number of
64
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
65
+
66
+ Examples:
67
+ >>> from torchrl.modules import MLP
68
+ >>> from torchrl.data import OneHot
69
+ >>> n_obs, n_act = 4, 3
70
+ >>> value_net = MLP(in_features=n_obs, out_features=n_act)
71
+ >>> spec = OneHot(n_act)
72
+ >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
73
+ >>> loss = DQNLoss(actor, action_space=spec)
74
+ >>> batch = [10,]
75
+ >>> data = TensorDict({
76
+ ... "observation": torch.randn(*batch, n_obs),
77
+ ... "action": spec.rand(batch),
78
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
79
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
80
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
81
+ ... ("next", "reward"): torch.randn(*batch, 1)
82
+ ... }, batch)
83
+ >>> loss(data)
84
+ TensorDict(
85
+ fields={
86
+ loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
87
+ batch_size=torch.Size([]),
88
+ device=None,
89
+ is_shared=False)
90
+
91
+ This class is compatible with non-tensordict based modules too and can be
92
+ used without recurring to any tensordict-related primitive. In this case,
93
+ the expected keyword arguments are:
94
+ ``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``,
95
+ and a single loss value is returned.
96
+
97
+ Examples:
98
+ >>> from torchrl.objectives import DQNLoss
99
+ >>> from torchrl.data import OneHot
100
+ >>> from torch import nn
101
+ >>> import torch
102
+ >>> n_obs = 3
103
+ >>> n_action = 4
104
+ >>> action_spec = OneHot(n_action)
105
+ >>> value_network = nn.Linear(n_obs, n_action) # a simple value model
106
+ >>> dqn_loss = DQNLoss(value_network, action_space=action_spec)
107
+ >>> # define data
108
+ >>> observation = torch.randn(n_obs)
109
+ >>> next_observation = torch.randn(n_obs)
110
+ >>> action = action_spec.rand()
111
+ >>> next_reward = torch.randn(1)
112
+ >>> next_done = torch.zeros(1, dtype=torch.bool)
113
+ >>> next_terminated = torch.zeros(1, dtype=torch.bool)
114
+ >>> loss_val = dqn_loss(
115
+ ... observation=observation,
116
+ ... next_observation=next_observation,
117
+ ... next_reward=next_reward,
118
+ ... next_done=next_done,
119
+ ... next_terminated=next_terminated,
120
+ ... action=action)
121
+
122
+ """
123
+
124
+ @dataclass
125
+ class _AcceptedKeys:
126
+ """Maintains default values for all configurable tensordict keys.
127
+
128
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
129
+ default values.
130
+
131
+ Attributes:
132
+ advantage (NestedKey): The input tensordict key where the advantage is expected.
133
+ Will be used for the underlying value estimator. Defaults to ``"advantage"``.
134
+ value_target (NestedKey): The input tensordict key where the target state value is expected.
135
+ Will be used for the underlying value estimator Defaults to ``"value_target"``.
136
+ value (NestedKey): The input tensordict key where the chosen action value is expected.
137
+ Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
138
+ action_value (NestedKey): The input tensordict key where the action value is expected.
139
+ Defaults to ``"action_value"``.
140
+ action (NestedKey): The input tensordict key where the action is expected.
141
+ Defaults to ``"action"``.
142
+ priority (NestedKey): The input tensordict key where the target priority is written to.
143
+ Defaults to ``"td_error"``.
144
+ reward (NestedKey): The input tensordict key where the reward is expected.
145
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
146
+ done (NestedKey): The key in the input TensorDict that indicates
147
+ whether a trajectory is done. Will be used for the underlying value estimator.
148
+ Defaults to ``"done"``.
149
+ terminated (NestedKey): The key in the input TensorDict that indicates
150
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
151
+ Defaults to ``"terminated"``.
152
+ """
153
+
154
+ advantage: NestedKey = "advantage"
155
+ value_target: NestedKey = "value_target"
156
+ value: NestedKey = "chosen_action_value"
157
+ action_value: NestedKey = "action_value"
158
+ action: NestedKey = "action"
159
+ priority: NestedKey = "td_error"
160
+ reward: NestedKey = "reward"
161
+ done: NestedKey = "done"
162
+ terminated: NestedKey = "terminated"
163
+ priority_weight: NestedKey = "priority_weight"
164
+
165
    # Instance-level view of the configurable keys (see ``_AcceptedKeys``).
    tensor_keys: _AcceptedKeys
    # Class holding the default key values, used by the base class machinery.
    default_keys = _AcceptedKeys
    # DQN defaults to a TD(0) value estimator (see ``make_value_estimator``).
    default_value_estimator = ValueEstimators.TD0
    # Single output entry produced by ``forward``.
    out_keys = ["loss"]

    # Functionalized Q-value network and its (target) parameter sets, populated
    # by ``convert_to_functional`` in ``__init__``.
    value_network: TensorDictModule
    value_network_params: TensorDictParams
    target_value_network_params: TensorDictParams
173
+
174
    def __init__(
        self,
        value_network: QValueActor | nn.Module,
        *,
        loss_function: str | None = "l2",
        delay_value: bool = True,
        double_dqn: bool = False,
        gamma: float | None = None,
        action_space: str | TensorSpec = None,
        priority_key: str | None = None,
        reduction: str | None = None,
        use_prioritized_weights: str | bool = "auto",
    ) -> None:
        # Default reduction is "mean", matching the class docstring.
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.use_prioritized_weights = use_prioritized_weights
        # Lazily computed by _set_in_keys() on first access of ``in_keys``.
        self._in_keys = None
        # Double DQN evaluates the online net's argmax with the target params,
        # so it is meaningless without a target network.
        if double_dqn and not delay_value:
            raise ValueError("double_dqn=True requires delay_value=True.")
        self.double_dqn = double_dqn
        # ``priority_key`` is deprecated; route it through the keys machinery.
        self._set_deprecated_ctor_keys(priority=priority_key)
        self.delay_value = delay_value
        # Wrap plain nn.Modules into a QValueActor so the loss can operate on
        # tensordicts regardless of the module type passed by the caller.
        value_network = ensure_tensordict_compatible(
            module=value_network,
            wrapper_type=QValueActor,
            action_space=action_space,
        )

        # Extract parameters (and, if ``delay_value``, a detached target copy)
        # into ``value_network_params`` / ``target_value_network_params``.
        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=self.delay_value,
        )

        self.value_network_in_keys = value_network.in_keys

        self.loss_function = loss_function
        if action_space is None:
            # infer from value net
            try:
                action_space = value_network.spec
            except AttributeError:
                # let's try with action_space then
                try:
                    action_space = value_network.action_space
                except AttributeError:
                    raise ValueError(
                        "The action space could not be retrieved from the value_network. "
                        "Make sure it is available to the DQN loss module."
                    )
        if action_space is None:
            # Legacy fallback: keep working but warn that an explicit space
            # will be required in the future.
            warnings.warn(
                "action_space was not specified. DQNLoss will default to 'one-hot'. "
                "This behavior will be deprecated soon and a space will have to be passed. "
                "Check the DQNLoss documentation to see how to pass the action space. "
            )
            action_space = "one-hot"
        # Normalize strings/specs to the internal action-space representation.
        self.action_space = _find_action_space(action_space)
        self.reduction = reduction
        # ``gamma`` must now be passed to the value estimator, not the loss.
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
236
+
237
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
238
+ if self._value_estimator is not None:
239
+ self._value_estimator.set_keys(
240
+ advantage=self.tensor_keys.advantage,
241
+ value_target=self.tensor_keys.value_target,
242
+ value=self.tensor_keys.value,
243
+ reward=self.tensor_keys.reward,
244
+ done=self.tensor_keys.done,
245
+ terminated=self.tensor_keys.terminated,
246
+ )
247
+ self._set_in_keys()
248
+
249
+ def _set_in_keys(self):
250
+ keys = [
251
+ self.tensor_keys.action,
252
+ ("next", self.tensor_keys.reward),
253
+ ("next", self.tensor_keys.done),
254
+ ("next", self.tensor_keys.terminated),
255
+ *self.value_network.in_keys,
256
+ *[("next", key) for key in self.value_network.in_keys],
257
+ ]
258
+ self._in_keys = list(set(keys))
259
+
260
+ @property
261
+ def in_keys(self):
262
+ if self._in_keys is None:
263
+ self._set_in_keys()
264
+ return self._in_keys
265
+
266
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
267
+ if value_type is None:
268
+ value_type = self.default_value_estimator
269
+
270
+ # Handle ValueEstimatorBase instance or class
271
+ if isinstance(value_type, ValueEstimatorBase) or (
272
+ isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
273
+ ):
274
+ return LossModule.make_value_estimator(self, value_type, **hyperparams)
275
+
276
+ self.value_type = value_type
277
+ hp = dict(default_value_kwargs(value_type))
278
+ if hasattr(self, "gamma"):
279
+ hp["gamma"] = self.gamma
280
+ hp.update(hyperparams)
281
+ if value_type is ValueEstimators.TD1:
282
+ self._value_estimator = TD1Estimator(**hp, value_network=self.value_network)
283
+ elif value_type is ValueEstimators.TD0:
284
+ self._value_estimator = TD0Estimator(**hp, value_network=self.value_network)
285
+ elif value_type is ValueEstimators.GAE:
286
+ raise NotImplementedError(
287
+ f"Value type {value_type} it not implemented for loss {type(self)}."
288
+ )
289
+ elif value_type is ValueEstimators.TDLambda:
290
+ self._value_estimator = TDLambdaEstimator(
291
+ **hp, value_network=self.value_network
292
+ )
293
+ else:
294
+ raise NotImplementedError(f"Unknown value type {value_type}")
295
+
296
+ tensor_keys = {
297
+ "advantage": self.tensor_keys.advantage,
298
+ "value_target": self.tensor_keys.value_target,
299
+ "value": self.tensor_keys.value,
300
+ "reward": self.tensor_keys.reward,
301
+ "done": self.tensor_keys.done,
302
+ "terminated": self.tensor_keys.terminated,
303
+ }
304
+ self._value_estimator.set_keys(**tensor_keys)
305
+
306
    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        """Computes the DQN loss given a tensordict sampled from the replay buffer.

        This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
        a priority to items in the tensordict.

        Args:
            tensordict (TensorDictBase): a tensordict with keys ["action"] and the in_keys of
                the value network (observations, "done", "terminated", "reward" in a "next" tensordict).

        Returns:
            a tensor containing the DQN loss.

        """
        # Run the online value network on a shallow copy so that the input
        # tensordict is not polluted with intermediate outputs.
        td_copy = tensordict.clone(False)
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)

        action = tensordict.get(self.tensor_keys.action)
        pred_val = td_copy.get(self.tensor_keys.action_value)

        # Select Q(s, a) for the action actually taken. "categorical" actions
        # are integer indices (gather); otherwise actions are assumed one-hot.
        if self.action_space == "categorical":
            if action.ndim != pred_val.ndim:
                # unsqueeze the action if it lacks on trailing singleton dim
                action = action.unsqueeze(-1)
            pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
        else:
            action = action.to(torch.float)
            pred_val_index = (pred_val * action).sum(-1)

        if self.double_dqn:
            # Double DQN: action selection with the online net, evaluation
            # with the target net.
            step_td = step_mdp(td_copy, keep_other=False)
            step_td_copy = step_td.clone(False)
            # Use online network to compute the action
            with self.value_network_params.data.to_module(self.value_network):
                self.value_network(step_td)
            next_action = step_td.get(self.tensor_keys.action)

            # Use target network to compute the values
            with self.target_value_network_params.to_module(self.value_network):
                self.value_network(step_td_copy)
            next_pred_val = step_td_copy.get(self.tensor_keys.action_value)

            if self.action_space == "categorical":
                if next_action.ndim != next_pred_val.ndim:
                    # unsqueeze the action if it lacks on trailing singleton dim
                    next_action = next_action.unsqueeze(-1)
                next_value = torch.gather(next_pred_val, -1, index=next_action)
            else:
                next_value = (next_pred_val * next_action).sum(-1, keepdim=True)
        else:
            # Plain DQN: let the value estimator bootstrap from the target
            # params directly.
            next_value = None
        target_value = self.value_estimator.value_estimate(
            td_copy,
            target_params=self.target_value_network_params,
            next_value=next_value,
        ).squeeze(-1)

        # TD error squared is used as the priority signal; detached via
        # no_grad so it never contributes gradients.
        with torch.no_grad():
            priority_tensor = (pred_val_index - target_value).pow(2)
            priority_tensor = priority_tensor.unsqueeze(-1)
        if tensordict.device is not None:
            priority_tensor = priority_tensor.to(tensordict.device)

        # Written in-place on the INPUT tensordict so a prioritized replay
        # buffer can pick it up.
        tensordict.set(
            self.tensor_keys.priority,
            priority_tensor,
            inplace=True,
        )
        loss = distance_loss(pred_val_index, target_value, self.loss_function)
        # Extract weights for prioritized replay buffer
        weights = None
        if (
            self.use_prioritized_weights in (True, "auto")
            and self.tensor_keys.priority_weight in tensordict.keys()
        ):
            weights = tensordict.get(self.tensor_keys.priority_weight)
        loss = _reduce(loss, reduction=self.reduction, weights=weights)
        td_out = TensorDict(loss=loss)

        # Drop cached weakrefs so functional params don't keep stale
        # references alive.
        self._clear_weakrefs(
            tensordict,
            td_out,
            "value_network_params",
            "target_value_network_params",
        )

        return td_out
397
class DistributionalDQNLoss(LossModule):
    """A distributional DQN loss class.

    Distributional DQN uses a value network that outputs a distribution of
    values over a discrete support of discounted returns (unlike regular DQN
    where the value network outputs a single point prediction of the
    discounted return).

    For more details regarding Distributional DQN, refer to "A Distributional
    Perspective on Reinforcement Learning",
    https://arxiv.org/pdf/1707.06887.pdf

    Args:
        value_network (DistributionalQValueActor or nn.Module): the distributional Q
            value operator.
        gamma (scalar): a discount factor for return computation.

            .. note::
              Unlike :class:`DQNLoss`, this class does not currently support
              custom value functions. The next value estimation is always
              bootstrapped.

        delay_value (bool): whether to duplicate the value network into a new
            target value network to create double DQN
        priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
            The key at which priority is assumed to be stored within TensorDicts added
            to this ReplayBuffer. This is to be used when the sampler is of type
            :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values

        Attributes:
            action_value (NestedKey): The input tensordict key where the action value is expected.
                Defaults to ``"action_value"``.
            action (NestedKey): The input tensordict key where the action is expected.
                Defaults to ``"action"``.
            priority (NestedKey): The input tensordict key where the target priority is written to.
                Defaults to ``"td_error"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Defaults to ``"reward"``.
            done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
                Defaults to ``"done"``.
            terminated (NestedKey): The input tensordict key where the flag if a trajectory is terminated is expected.
                Defaults to ``"terminated"``.
            steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is expected.
                Defaults to ``"steps_to_next_obs"``.
            priority_weight (NestedKey): The input tensordict key where importance-sampling weights
                (from a prioritized sampler) are expected. Defaults to ``"priority_weight"``.
        """

        action_value: NestedKey = "action_value"
        action: NestedKey = "action"
        priority: NestedKey = "td_error"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"
        steps_to_next_obs: NestedKey = "steps_to_next_obs"
        priority_weight: NestedKey = "priority_weight"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator = ValueEstimators.TD0

    value_network: TensorDictModule
    value_network_params: TensorDictParams
    target_value_network_params: TensorDictParams

    def __init__(
        self,
        value_network: DistributionalQValueActor | nn.Module,
        *,
        gamma: float,
        delay_value: bool = True,
        priority_key: str | None = None,
        reduction: str | None = None,
        use_prioritized_weights: str | bool = "auto",
    ):
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.use_prioritized_weights = use_prioritized_weights
        self._set_deprecated_ctor_keys(priority=priority_key)
        # gamma is registered as a buffer so it follows the module across
        # devices and is saved in state dicts.
        self.register_buffer("gamma", torch.tensor(gamma))
        self.delay_value = delay_value

        # Wrap raw nn.Modules into a DistributionalQValueActor if needed.
        value_network = ensure_tensordict_compatible(
            module=value_network, wrapper_type=DistributionalQValueActor
        )

        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=self.delay_value,
        )
        self.action_space = self.value_network.action_space
        self.reduction = reduction

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # No estimator to synchronize: the TD(0) target is computed inline in
        # forward (see make_value_estimator).
        pass

    @staticmethod
    def _log_ps_a_default(action, action_log_softmax, batch_size, atoms):
        """Select log p(s_t, a_t) for one-hot encoded actions."""
        action_expand = action.unsqueeze(-2).expand_as(action_log_softmax)
        log_ps_a = action_log_softmax.masked_select(action_expand.to(torch.bool))
        log_ps_a = log_ps_a.view(batch_size, atoms)  # log p(s_t, a_t; θonline)
        return log_ps_a

    @staticmethod
    def _log_ps_a_categorical(action, action_log_softmax):
        """Select log p(s_t, a_t) for integer (categorical) actions."""
        # Reshaping action of shape `[*batch_sizes, 1]` to `[*batch_sizes, atoms, 1]` for gather.
        if action.shape[-1] != 1:
            action = action.unsqueeze(-1)
        action = action.unsqueeze(-2)
        new_shape = [-1] * len(action.shape)
        new_shape[-2] = action_log_softmax.shape[-2]  # calculating atoms
        action = action.expand(new_shape)
        return torch.gather(action_log_softmax, -1, index=action).squeeze(-1)

    def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
        """Computes the distributional DQN (C51) loss.

        The target distribution is the Bellman-updated support projected back
        onto the fixed support (L2 projection), and the loss is the
        cross-entropy between that projection and the online log-probabilities.
        A "td_error" priority entry is written in-place on the input
        tensordict.

        Args:
            input_tensordict (TensorDictBase): a 1-dimensional tensordict with
                the action, the value-network inputs and the ``("next", ...)``
                reward/done/terminated entries.

        Returns:
            a TensorDict holding the reduced ``loss``.

        Raises:
            RuntimeError: if the input is not 1-dimensional, if the Bellman
                target has an unexpected shape, or contains non-finite values.
        """
        # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py
        tensordict = TensorDict(
            source=input_tensordict, batch_size=input_tensordict.batch_size
        )

        if tensordict.batch_dims != 1:
            # Fix: previously read ``self.__class__.__name___`` (three
            # trailing underscores), which raised AttributeError instead of
            # this RuntimeError.
            raise RuntimeError(
                f"{self.__class__.__name__} expects a 1-dimensional "
                "tensordict as input"
            )
        batch_size = tensordict.batch_size[0]
        support = self.value_network_params["support"]
        atoms = support.numel()
        Vmin = support.min().item()
        Vmax = support.max().item()
        delta_z = (Vmax - Vmin) / (atoms - 1)

        action = tensordict.get(self.tensor_keys.action)
        reward = tensordict.get(("next", self.tensor_keys.reward))
        done = tensordict.get(("next", self.tensor_keys.done))
        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)

        # n-step returns: discount is gamma**n when steps_to_next_obs is set.
        steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, 1)
        discount = self.gamma**steps_to_next_obs

        # Calculate current state probabilities (online network noise already
        # sampled)
        td_clone = tensordict.clone()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(
                td_clone,
            )  # Log probabilities log p(s_t, ·; θonline)
        action_log_softmax = td_clone.get(self.tensor_keys.action_value)

        if self.action_space == "categorical":
            log_ps_a = self._log_ps_a_categorical(action, action_log_softmax)
        else:
            log_ps_a = self._log_ps_a_default(
                action, action_log_softmax, batch_size, atoms
            )

        # Everything below builds the (gradient-free) target distribution m.
        with torch.no_grad(), self.value_network_params.to_module(self.value_network):
            # Calculate nth next state probabilities
            next_td = step_mdp(tensordict)
            self.value_network(next_td)  # Probabilities p(s_t+n, ·; θonline)

            next_td_action = next_td.get(self.tensor_keys.action)
            if self.action_space == "categorical":
                argmax_indices_ns = next_td_action.squeeze(-1)
            else:
                argmax_indices_ns = next_td_action.argmax(-1)  # one-hot encoding
            with self.target_value_network_params.to_module(self.value_network):
                self.value_network(next_td)  # Probabilities p(s_t+n, ·; θtarget)
            pns = next_td.get(self.tensor_keys.action_value).exp()
            # Double-Q probabilities
            # p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
            pns_a = pns[range(batch_size), :, argmax_indices_ns]

            # Compute Tz (Bellman operator T applied to z)
            # Tz = R^n + (γ^n)z (accounting for terminal states)
            if isinstance(discount, torch.Tensor):
                discount = discount.to("cpu")
            # done = done.to("cpu")
            terminated = terminated.to("cpu")
            reward = reward.to("cpu")
            support = support.to("cpu")
            pns_a = pns_a.to("cpu")

            Tz = reward + (1 - terminated.to(reward.dtype)) * discount * support
            if Tz.shape != torch.Size([batch_size, atoms]):
                raise RuntimeError(
                    "Tz shape must be torch.Size([batch_size, atoms]), "
                    f"got Tz.shape={Tz.shape} and batch_size={batch_size}, "
                    f"atoms={atoms}"
                )
            # Clamp between supported values
            Tz = Tz.clamp_(min=Vmin, max=Vmax)
            if not torch.isfinite(Tz).all():
                raise RuntimeError("Tz has some non-finite elements")
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - Vmin) / delta_z  # b = (Tz - Vmin) / Δz
            low, up = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            low[(up > 0) & (low == up)] -= 1
            up[(low < (atoms - 1)) & (low == up)] += 1

            # Distribute probability of Tz
            m = torch.zeros(batch_size, atoms)
            # Flattened-row offsets so index_add_ can work on a 1D view.
            offset = torch.linspace(
                0,
                ((batch_size - 1) * atoms),
                batch_size,
                dtype=torch.int64,
                # device=device,
            )
            offset = offset.unsqueeze(1).expand(batch_size, atoms)
            index = (low + offset).view(-1)
            tensor = (pns_a * (up.float() - b)).view(-1)
            # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(0, index, tensor)
            index = (up + offset).view(-1)
            tensor = (pns_a * (b - low.float())).view(-1)
            # m_u = m_u + p(s_t+n, a*)(b - l)
            m.view(-1).index_add_(0, index, tensor)

        # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        loss = -torch.sum(m.to(input_tensordict.device) * log_ps_a, 1)
        # Priority for prioritized replay, written on the input tensordict.
        input_tensordict.set(
            self.tensor_keys.priority,
            loss.detach().unsqueeze(1).to(input_tensordict.device),
            inplace=True,
        )
        # Extract weights for prioritized replay buffer
        weights = None
        if (
            self.use_prioritized_weights in (True, "auto")
            and self.tensor_keys.priority_weight in tensordict.keys()
        ):
            weights = tensordict.get(self.tensor_keys.priority_weight)
        loss = _reduce(loss, reduction=self.reduction, weights=weights)
        td_out = TensorDict(loss=loss)
        self._clear_weakrefs(
            tensordict,
            td_out,
            "value_network_params",
            "target_value_network_params",
        )
        return td_out

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        """Validates the requested value estimator type.

        Only ``TD0`` (handled inline in :meth:`forward`) and custom
        :class:`ValueEstimatorBase` estimators are supported.

        Raises:
            NotImplementedError: for TD1, GAE, TDLambda or unknown types.
        """
        if value_type is None:
            value_type = self.default_value_estimator

        # Handle ValueEstimatorBase instance or class
        if isinstance(value_type, ValueEstimatorBase) or (
            isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
        ):
            return LossModule.make_value_estimator(self, value_type, **hyperparams)

        self.value_type = value_type
        if value_type is ValueEstimators.TD1:
            raise NotImplementedError(
                f"value type {value_type} is not implemented for {self.__class__.__name__}."
            )
        elif value_type is ValueEstimators.TD0:
            # see forward call
            pass
        elif value_type is ValueEstimators.GAE:
            raise NotImplementedError(
                f"value type {value_type} is not implemented for {self.__class__.__name__}."
            )
        elif value_type is ValueEstimators.TDLambda:
            raise NotImplementedError(
                f"value type {value_type} is not implemented for {self.__class__.__name__}."
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

    def _default_value_estimator(self):
        # Default to TD0, the only natively supported estimator type.
        self.make_value_estimator(ValueEstimators.TD0)