torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,2052 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import abc
9
+ import itertools
10
+ import pathlib
11
+ import time
12
+ import warnings
13
+ from collections import defaultdict, OrderedDict
14
+ from collections.abc import Callable, Sequence
15
+ from copy import deepcopy
16
+ from textwrap import indent
17
+ from typing import Any, Literal
18
+
19
+ import numpy as np
20
+ import torch.nn
21
+ from tensordict import NestedKey, pad, TensorDict, TensorDictBase
22
+ from tensordict._tensorcollection import TensorCollection
23
+ from tensordict.nn import TensorDictModule
24
+ from tensordict.utils import expand_right
25
+ from torch import nn, optim
26
+
27
+ from torchrl._utils import (
28
+ _CKPT_BACKEND,
29
+ KeyDependentDefaultDict,
30
+ logger as torchrl_logger,
31
+ rl_warnings,
32
+ timeit,
33
+ VERBOSE,
34
+ )
35
+ from torchrl.collectors import BaseCollector
36
+ from torchrl.collectors.utils import split_trajectories
37
+ from torchrl.data.replay_buffers import (
38
+ PrioritizedSampler,
39
+ TensorDictPrioritizedReplayBuffer,
40
+ TensorDictReplayBuffer,
41
+ )
42
+ from torchrl.data.utils import DEVICE_TYPING
43
+ from torchrl.envs.common import EnvBase
44
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
45
+ from torchrl.objectives.common import LossModule
46
+ from torchrl.objectives.utils import TargetNetUpdater
47
+ from torchrl.record.loggers import Logger
48
+
49
+ try:
50
+ from tqdm import tqdm
51
+
52
+ _has_tqdm = True
53
+ except ImportError:
54
+ _has_tqdm = False
55
+
56
+ try:
57
+ from torchsnapshot import Snapshot, StateDict
58
+
59
+ _has_ts = True
60
+ except ImportError:
61
+ _has_ts = False
62
+
63
# Lookup table: replay-buffer flavour name -> TensorDict replay buffer class.
REPLAY_BUFFER_CLASS = dict(
    prioritized=TensorDictPrioritizedReplayBuffer,
    circular=TensorDictReplayBuffer,
)

# Metric name -> name of the Logger method used to record that metric.
# Both currently tracked metrics are scalars.
LOGGER_METHODS = dict.fromkeys(("grad_norm", "loss"), "log_scalar")

# Python type -> format spec applied when a value of that type is shown in
# the progress bar (floats get 4 decimals, ints use the default format).
TYPE_DESCR = {float: "4.4f", int: ""}

# Nested key of the reward entry (stored under "next") in collected batches.
REWARD_KEY = ("next", "reward")
77
+
78
+
79
class TrainerHookBase:
    """Base class for hooks that plug into a torchrl :class:`Trainer`.

    A hook subclass provides three methods: ``state_dict`` and
    ``load_state_dict`` to export/restore its own state, and ``register`` to
    attach itself to a trainer.

    .. note::
        The methods below are decorated with ``@abc.abstractmethod``, but this
        class neither derives from :class:`abc.ABC` nor uses ``abc.ABCMeta``,
        so the decorator is not enforced when instantiating: a subclass that
        omits one of these methods can still be created, and only raises
        :class:`NotImplementedError` once the missing method is called.
    """

    @abc.abstractmethod
    def state_dict(self) -> dict[str, Any]:
        """Return the hook's state as a dictionary."""
        raise NotImplementedError()

    @abc.abstractmethod
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the hook's state from ``state_dict``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def register(self, trainer: Trainer, name: str):
        """Attach this hook to ``trainer`` at the hook's default location.

        Args:
            trainer (Trainer): trainer that will own the hook.
            name (str): name under which the hook is registered.

        .. note::
            To place the hook somewhere other than its default location, use
            :meth:`~torchrl.trainers.Trainer.register_op` instead.

        """
        raise NotImplementedError()
104
+
105
+
106
+ class Trainer:
107
+ """A generic Trainer class.
108
+
109
+ A trainer is responsible for collecting data and training the model.
110
+ To keep the class as versatile as possible, Trainer does not construct any
111
+ of its specific operations: they all must be hooked at specific points in
112
+ the training loop.
113
+
114
+ To build a Trainer, one needs an iterable data source (a :obj:`collector`), a
115
+ loss module and an optimizer.
116
+
117
+ Args:
118
+ collector (Sequence[TensorDictBase]): An iterable returning batches of
119
+ data in a TensorDict form of shape [batch x time steps].
120
+ total_frames (int): Total number of frames to be collected during
121
+ training.
122
+ loss_module (LossModule): A module that reads TensorDict batches
123
+ (possibly sampled from a replay buffer) and return a loss
124
+ TensorDict where every key points to a different loss component.
125
+ optimizer (optim.Optimizer): An optimizer that trains the parameters
126
+ of the model.
127
+ logger (Logger, optional): a Logger that will handle the logging.
128
+ optim_steps_per_batch (int, optional): number of optimization steps
129
+ per collection of data. A trainer works as follows: a main loop
130
+ collects batches of data (epoch loop), and a sub-loop (training
131
+ loop) performs model updates in between two collections of data.
132
+ If `None`, the trainer will use the number of workers as the number of optimization steps.
133
+ clip_grad_norm (bool, optional): If True, the gradients will be clipped
134
+ based on the total norm of the model parameters. If False,
135
+ all the partial derivatives will be clamped to
136
+ (-clip_norm, clip_norm). Default is ``True``.
137
+ clip_norm (Number, optional): value to be used for clipping gradients.
138
+ Default is None (no clip norm).
139
+ progress_bar (bool, optional): If True, a progress bar will be
140
+ displayed using tqdm. If tqdm is not installed, this option
141
+ won't have any effect. Default is ``True``
142
+ seed (int, optional): Seed to be used for the collector, pytorch and
143
+ numpy. Default is ``None``.
144
+ save_trainer_interval (int, optional): How often the trainer should be
145
+ saved to disk, in frame count. Default is 10000.
146
+ log_interval (int, optional): How often the values should be logged,
147
+ in frame count. Default is 10000.
148
+ save_trainer_file (path, optional): path where to save the trainer.
149
+ Default is None (no saving)
150
+ async_collection (bool, optional): Whether to collect data asynchronously.
151
+ This will only work if the replay buffer is registered within the data collector.
152
+ If using this, the UTD ratio (Update to Data) will be logged under the key "utd_ratio".
153
+ Default is False.
154
+ log_timings (bool, optional): If True, automatically register a LogTiming hook to log
155
+ timing information for all hooks to the logger (e.g., wandb, tensorboard).
156
+ Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
157
+ Default is False.
158
+ """
159
+
160
+ @classmethod
161
+ def __new__(cls, *args, **kwargs):
162
+ # Training state trackers (used for logging and checkpointing)
163
+ cls._optim_count: int = 0 # Total number of optimization steps completed
164
+ cls._collected_frames: int = 0 # Total number of frames collected (deprecated)
165
+ cls._last_log: dict[
166
+ str, Any
167
+ ] = {} # Tracks when each metric was last logged (for log_interval control)
168
+ cls._last_save: int = (
169
+ 0 # Tracks when trainer was last saved (for save_interval control)
170
+ )
171
+ cls.collected_frames = 0 # Total number of frames collected (current)
172
+ cls._app_state = None # Application state for checkpointing
173
+ return super().__new__(cls)
174
+
175
+ def __init__(
176
+ self,
177
+ *,
178
+ collector: BaseCollector,
179
+ total_frames: int,
180
+ frame_skip: int,
181
+ optim_steps_per_batch: int,
182
+ loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
183
+ optimizer: optim.Optimizer | None = None,
184
+ logger: Logger | None = None,
185
+ clip_grad_norm: bool = True,
186
+ clip_norm: float | None = None,
187
+ progress_bar: bool = True,
188
+ seed: int | None = None,
189
+ save_trainer_interval: int = 10000,
190
+ log_interval: int = 10000,
191
+ save_trainer_file: str | pathlib.Path | None = None,
192
+ num_epochs: int = 1,
193
+ async_collection: bool = False,
194
+ log_timings: bool = False,
195
+ ) -> None:
196
+ # objects
197
+ self.frame_skip = frame_skip
198
+ self.collector = collector
199
+ self.loss_module = loss_module
200
+ self.optimizer = optimizer
201
+ self.logger = logger
202
+ self.async_collection = async_collection
203
+
204
+ # Logging frequency control - how often to log each metric (in frames)
205
+ self._log_interval = log_interval
206
+
207
+ # seeding
208
+ self.seed = seed
209
+ if seed is not None:
210
+ self.set_seed()
211
+
212
+ # constants
213
+ self.optim_steps_per_batch = optim_steps_per_batch
214
+ self.total_frames = total_frames
215
+ self.num_epochs = num_epochs
216
+ self.clip_grad_norm = clip_grad_norm
217
+ self.clip_norm = clip_norm
218
+ if progress_bar and not _has_tqdm:
219
+ warnings.warn(
220
+ "tqdm library not found. "
221
+ "Consider installing tqdm to use the Trainer progress bar."
222
+ )
223
+ self.progress_bar = progress_bar and _has_tqdm
224
+ self.save_trainer_interval = save_trainer_interval
225
+ self.save_trainer_file = save_trainer_file
226
+
227
+ self._log_dict = defaultdict(list)
228
+
229
+ # Hook collections for different stages of the training loop
230
+ self._batch_process_ops = (
231
+ []
232
+ ) # Process collected batches (e.g., reward normalization)
233
+ self._post_steps_ops = [] # After optimization steps (e.g., weight updates)
234
+
235
+ # Logging hook collections - different points in training loop where logging can occur
236
+ self._post_steps_log_ops = (
237
+ []
238
+ ) # After optimization steps (e.g., validation rewards)
239
+ self._pre_steps_log_ops = (
240
+ []
241
+ ) # Before optimization steps (e.g., rewards, frame counts)
242
+ self._post_optim_log_ops = (
243
+ []
244
+ ) # After each optimization step (e.g., gradient norms)
245
+ self._pre_epoch_log_ops = (
246
+ []
247
+ ) # Before each epoch logging (e.g., epoch-specific metrics)
248
+ self._post_epoch_log_ops = (
249
+ []
250
+ ) # After each epoch logging (e.g., epoch completion metrics)
251
+
252
+ # Regular hook collections for non-logging operations
253
+ self._pre_epoch_ops = (
254
+ []
255
+ ) # Before each epoch (e.g., epoch setup, cache clearing)
256
+ self._post_epoch_ops = (
257
+ []
258
+ ) # After each epoch (e.g., epoch cleanup, weight syncing)
259
+
260
+ # Optimization-related hook collections
261
+ self._pre_optim_ops = [] # Before optimization steps (e.g., cache clearing)
262
+ self._post_loss_ops = (
263
+ []
264
+ ) # After loss computation, operates on batch (e.g., priority updates)
265
+ self._process_loss_ops = (
266
+ []
267
+ ) # Transform loss values before optimizer (e.g., scaling, clipping)
268
+ self._optimizer_ops = [] # During optimization (e.g., gradient clipping)
269
+ self._process_optim_batch_ops = (
270
+ []
271
+ ) # Process batches for optimization (e.g., subsampling)
272
+ self._post_optim_ops = [] # After optimization (e.g., weight syncing)
273
+
274
+ self._modules = {}
275
+
276
+ if self.optimizer is not None:
277
+ optimizer_hook = OptimizerHook(self.optimizer)
278
+ optimizer_hook.register(self)
279
+
280
+ if log_timings:
281
+ log_timing_hook = LogTiming(prefix="time", percall=True, erase=False)
282
+ log_timing_hook.register(self)
283
+
284
+ def register_module(self, module_name: str, module: Any) -> None:
285
+ if module_name in self._modules:
286
+ raise RuntimeError(
287
+ f"{module_name} is already registered, choose a different name."
288
+ )
289
+ self._modules[module_name] = module
290
+
291
def _wrap_hook_with_timing(
    self, op: Callable, hook_name: str | None = None
) -> Callable:
    """Return a callable that runs ``op`` inside a ``timeit("hook/<name>")`` block.

    Args:
        op: The hook/operation to wrap.
        hook_name: Label used in the timing key. When omitted, ``op.__name__``
            is used, falling back to the name of ``op``'s class (useful for
            callable instances that have no ``__name__``).

    Returns:
        A wrapper that forwards every argument to ``op`` and returns its
        result. The wrapper exposes ``op`` as ``__wrapped__`` and carries
        ``hook_name`` as its ``__name__`` to ease debugging.
    """
    if hook_name is None:
        if hasattr(op, "__class__"):
            fallback = op.__class__.__name__
        else:
            fallback = "unknown_hook"
        hook_name = getattr(op, "__name__", fallback)

    timing_key = f"hook/{hook_name}"

    def timed_hook(*args, **kwargs):
        with timeit(timing_key):
            return op(*args, **kwargs)

    timed_hook.__wrapped__ = op
    timed_hook.__name__ = hook_name
    return timed_hook
318
+
319
def _get_state(self):
    """Return the trainer's bookkeeping counters as a mapping.

    The mapping holds ``collected_frames``, ``_last_log``, ``_last_save`` and
    ``_optim_count``. Its type follows the checkpointing backend: a
    torchsnapshot ``StateDict`` when ``_CKPT_BACKEND == "torchsnapshot"``,
    an ``OrderedDict`` otherwise.
    """
    fields = {
        "collected_frames": self.collected_frames,
        "_last_log": self._last_log,
        "_last_save": self._last_save,
        "_optim_count": self._optim_count,
    }
    container = StateDict if _CKPT_BACKEND == "torchsnapshot" else OrderedDict
    return container(**fields)
335
+
336
@property
def app_state(self):
    """Build (and cache on ``self._app_state``) the torchsnapshot app-state mapping.

    The mapping contains the trainer counters wrapped in a ``StateDict`` under
    ``"state"``, the collector, the loss module, and every registered module
    under its registration name.
    """
    counters = StateDict(**self._get_state())
    self._app_state = {
        "state": counters,
        "collector": self.collector,
        "loss_module": self.loss_module,
        **self._modules,
    }
    return self._app_state
345
+
346
def state_dict(self) -> dict:
    """Return a checkpointable snapshot of the trainer.

    The returned ``OrderedDict`` has, in order: the collector state, the loss
    module state, the trainer counters (``"state"``), then one entry per
    registered module keyed by its registration name.
    """
    counters = self._get_state()
    collector_sd = self.collector.state_dict()
    loss_sd = self.loss_module.state_dict()
    module_sds = {name: module.state_dict() for name, module in self._modules.items()}
    return OrderedDict(
        collector=collector_sd,
        loss_module=loss_sd,
        state=counters,
        **module_sds,
    )
355
+
356
def load_state_dict(self, state_dict: dict) -> None:
    """Restore the trainer from a mapping produced by :meth:`state_dict`.

    The loss module, the collector and every registered module are restored
    through their own ``load_state_dict``. The trainer counters are then read
    from the ``"state"`` entry.
    """
    loss_sd = state_dict["loss_module"]
    collector_sd = state_dict["collector"]

    self.loss_module.load_state_dict(loss_sd)
    self.collector.load_state_dict(collector_sd)
    for name, module in self._modules.items():
        module.load_state_dict(state_dict[name])

    counters = state_dict["state"]
    for attr in ("collected_frames", "_last_log", "_last_save", "_optim_count"):
        setattr(self, attr, counters[attr])
369
+
370
def _save_trainer(self) -> None:
    """Write a checkpoint to ``self.save_trainer_file`` with the active backend.

    * ``"torch"``: ``torch.save`` of :meth:`state_dict`.
    * ``"torchsnapshot"``: ``Snapshot.take`` of :attr:`app_state`; requires
      torchsnapshot to be installed.

    Raises:
        ImportError: torchsnapshot backend selected but not installed.
        NotImplementedError: unknown backend.
    """
    if _CKPT_BACKEND == "torch":
        torch.save(self.state_dict(), self.save_trainer_file)
    elif _CKPT_BACKEND == "torchsnapshot":
        if not _has_ts:
            raise ImportError(
                "torchsnapshot not found. Consider installing torchsnapshot or "
                "using the torch checkpointing backend (`CKPT_BACKEND=torch`)"
            )
        Snapshot.take(app_state=self.app_state, path=self.save_trainer_file)
    else:
        raise NotImplementedError(
            f"CKPT_BACKEND should be one of {_CKPT_BACKEND.backends}, got {_CKPT_BACKEND}."
        )
384
+
385
def save_trainer(self, force_save: bool = False) -> None:
    """Checkpoint the trainer when due, or unconditionally if ``force_save``.

    A save is due once more than ``save_trainer_interval`` frames have been
    collected since the last periodic save; ``_last_save`` is then updated.
    Nothing is written when ``save_trainer_file`` is unset or empty.
    """
    target = self.save_trainer_file
    should_save = force_save
    if target is not None and (
        self.collected_frames - self._last_save
    ) > self.save_trainer_interval:
        self._last_save = self.collected_frames
        should_save = True
    if should_save and target:
        self._save_trainer()
393
+
394
def load_from_file(self, file: str | pathlib.Path, **kwargs) -> Trainer:
    """Loads a file and its state-dict in the trainer.

    With the ``"torch"`` checkpointing backend, the file is read with
    :func:`~torch.load` (keyword arguments are forwarded to it) and passed to
    :meth:`load_state_dict`. With the ``"torchsnapshot"`` backend, the
    snapshot at ``file`` is restored into :attr:`app_state`; ``kwargs`` are
    ignored in that case.

    Returns:
        ``self``, to allow chaining.

    Raises:
        ImportError: torchsnapshot backend selected but not installed.
        NotImplementedError: unknown checkpointing backend (previously this
            was silently ignored and the trainer was left untouched).
    """
    if _CKPT_BACKEND == "torchsnapshot":
        # Same guard as in ``_save_trainer``: fail with an actionable message.
        if not _has_ts:
            raise ImportError(
                "torchsnapshot not found. Consider installing torchsnapshot or "
                "using the torch checkpointing backend (`CKPT_BACKEND=torch`)"
            )
        snapshot = Snapshot(path=file)
        snapshot.restore(app_state=self.app_state)
    elif _CKPT_BACKEND == "torch":
        loaded_dict: OrderedDict = torch.load(file, **kwargs)
        self.load_state_dict(loaded_dict)
    else:
        raise NotImplementedError(
            f"CKPT_BACKEND should be one of {_CKPT_BACKEND.backends}, got {_CKPT_BACKEND}."
        )
    return self
407
+
408
def set_seed(self):
    """Seed the collector, then the global torch and numpy RNGs.

    ``self.seed`` is forwarded to ``collector.set_seed(..., static_seed=False)``.
    The seed that call returns is used for ``torch.manual_seed`` and
    ``np.random.seed``.
    """
    collector_seed = self.collector.set_seed(self.seed, static_seed=False)
    torch.manual_seed(collector_seed)
    np.random.seed(collector_seed)
412
+
413
@property
def collector(self) -> BaseCollector:
    """Data collector driving the training loop (backed by ``self._collector``)."""
    return self._collector

@collector.setter
def collector(self, collector: BaseCollector) -> None:
    # Plain assignment: no validation is performed on the new collector.
    self._collector = collector
420
+
421
def register_op(
    self,
    dest: Literal[
        "batch_process",
        "pre_optim_steps",
        "process_optim_batch",
        "post_loss",
        "process_loss",
        "optimizer",
        "post_steps",
        "post_optim",
        "pre_steps_log",
        "post_steps_log",
        "post_optim_log",
        "pre_epoch_log",
        "post_epoch_log",
        "pre_epoch",
        "post_epoch",
    ],
    op: Callable,
    **kwargs,
) -> None:
    """Register ``op`` in the hook collection named ``dest``.

    The op's signature is checked with ``_check_input_output_typehint``
    against the input/output hints expected for ``dest``. The op is then
    wrapped with :meth:`_wrap_hook_with_timing` and appended, together with
    ``kwargs``, to the matching hook list. ``kwargs`` are forwarded to the op
    on every call.

    When ``op`` is (or is the ``__call__`` of) a module previously passed to
    :meth:`register_module`, the registration name is used as the timing
    label.

    Args:
        dest: name of the hook collection.
        op: the callable to register.
        **kwargs: extra keyword arguments passed to ``op`` on each call.

    Raises:
        RuntimeError: if ``dest`` is not a known hook collection.
    """
    td_io = (TensorDictBase, TensorDictBase)
    log_io = (TensorDictBase, tuple[str, float])
    no_io = (None, None)
    # dest -> (name of the hook-list attribute, (expected input, expected output))
    destinations = {
        "batch_process": ("_batch_process_ops", td_io),
        "pre_optim_steps": ("_pre_optim_ops", no_io),
        "process_optim_batch": ("_process_optim_batch_ops", td_io),
        "post_loss": ("_post_loss_ops", td_io),
        "process_loss": ("_process_loss_ops", td_io),
        "optimizer": (
            "_optimizer_ops",
            ([TensorDictBase, bool, float, int], TensorDictBase),
        ),
        "post_steps": ("_post_steps_ops", no_io),
        "post_optim": ("_post_optim_ops", no_io),
        "pre_steps_log": ("_pre_steps_log_ops", log_io),
        "post_steps_log": ("_post_steps_log_ops", log_io),
        "post_optim_log": ("_post_optim_log_ops", log_io),
        "pre_epoch_log": ("_pre_epoch_log_ops", log_io),
        "post_epoch_log": ("_post_epoch_log_ops", log_io),
        "pre_epoch": ("_pre_epoch_ops", no_io),
        "post_epoch": ("_post_epoch_ops", no_io),
    }
    try:
        attr_name, (expected_input, expected_output) = destinations[dest]
    except (KeyError, TypeError):
        raise RuntimeError(
            f"The hook collection {dest} is not recognised. Choose from:"
            f"(batch_process, pre_optim_steps, process_optim_batch, post_loss, "
            f"process_loss, optimizer, post_steps, post_optim, pre_steps_log, "
            f"post_steps_log, post_optim_log, pre_epoch_log, post_epoch_log, "
            f"pre_epoch, post_epoch)"
        ) from None

    # Look up a registration name for the timing label. Bound methods are
    # created anew on every attribute access, so ``module.__call__`` must be
    # compared with ``==`` (same ``__self__``/``__func__``), not ``is``.
    hook_name = next(
        (
            name
            for name, module in self._modules.items()
            if module is op or (callable(module) and module.__call__ == op)
        ),
        None,
    )
    timed_op = self._wrap_hook_with_timing(op, hook_name)

    _check_input_output_typehint(op, input=expected_input, output=expected_output)
    getattr(self, attr_name).append((timed_op, kwargs))

register_hook = register_op
543
+
544
# Process batch
def _process_batch_hook(self, batch: TensorDictBase) -> TensorDictBase:
    """Pipe a freshly collected batch through the ``batch_process`` hooks.

    Hooks run in registration order as ``op(batch, **kwargs)``. A hook's
    output becomes the next hook's input only when it is a TensorDict;
    any other return value is ignored.
    """
    current = batch
    for hook, hook_kwargs in self._batch_process_ops:
        candidate = hook(current, **hook_kwargs)
        if not isinstance(candidate, TensorDictBase):
            continue
        current = candidate
    return current
551
+
552
def _post_steps_hook(self) -> None:
    """Invoke each ``post_steps`` hook as ``op(**kwargs)``, in registration order."""
    for hook, hook_kwargs in self._post_steps_ops:
        hook(**hook_kwargs)
555
+
556
def _post_optim_log(self, batch: TensorDictBase) -> None:
    """Execute logging hooks that run AFTER EACH optimization step.

    Each ``post_optim_log`` hook is called as ``op(batch, **kwargs)``. A
    non-``None`` result is treated as a mapping of metric names to values and
    forwarded to :meth:`_log`. Called once per step inside the optimization
    loop (e.g. for gradient norms or per-step loss components).
    """
    for hook, hook_kwargs in self._post_optim_log_ops:
        metrics = hook(batch, **hook_kwargs)
        if metrics is None:
            continue
        self._log(**metrics)
567
+
568
def _pre_optim_hook(self):
    """Invoke each ``pre_optim_steps`` hook as ``op(**kwargs)`` before optimizing."""
    for hook, hook_kwargs in self._pre_optim_ops:
        hook(**hook_kwargs)
571
+
572
def _process_optim_batch_hook(self, batch):
    """Derive an optimization sub-batch through the ``process_optim_batch`` hooks.

    Hooks run in order as ``op(batch, **kwargs)``; only TensorDict outputs
    replace the running batch. Exceptions such as ``StopIteration`` propagate
    to the caller (``optim_steps`` uses it to end an epoch).
    """
    current = batch
    for hook, hook_kwargs in self._process_optim_batch_ops:
        candidate = hook(current, **hook_kwargs)
        if not isinstance(candidate, TensorDictBase):
            continue
        current = candidate
    return current
578
+
579
def _post_loss_hook(self, batch):
    """Run the ``post_loss`` hooks (e.g. priority updates) on the sub-batch.

    Each hook is called as ``op(batch, **kwargs)``; a TensorDict result
    replaces the batch handed to the next hook, and the final batch is
    returned.
    """
    current = batch
    for hook, hook_kwargs in self._post_loss_ops:
        candidate = hook(current, **hook_kwargs)
        if not isinstance(candidate, TensorDictBase):
            continue
        current = candidate
    return current
585
+
586
def _process_loss_hook(
    self, sub_batch: TensorDictBase, losses_td: TensorDictBase
) -> TensorDictBase:
    """Apply registered loss transformation hooks before optimization.

    Unlike ``post_loss`` hooks, which operate on the batch (e.g. for priority
    updates), ``process_loss`` hooks transform the loss TensorDict itself.
    Each hook is called as ``op(sub_batch, losses, **kwargs)``; when it
    returns a TensorDict, that result becomes the losses seen by the next
    hook.

    Typical uses: loss scaling, clipping, or applying importance weights.

    Args:
        sub_batch: The batch of data used to compute the losses.
        losses_td: The TensorDict containing loss components from the loss module.

    Returns:
        The (possibly modified) losses TensorDict.
    """
    current_losses = losses_td
    for hook, hook_kwargs in self._process_loss_ops:
        transformed = hook(sub_batch, current_losses, **hook_kwargs)
        if isinstance(transformed, TensorDictBase):
            current_losses = transformed
    return current_losses
609
+
610
def _optimizer_hook(self, batch):
    """Run the ``optimizer`` hooks on the losses and return them detached.

    Each hook receives the current losses, ``self.clip_grad_norm``,
    ``self.clip_norm`` and its position in the hook list (plus its registered
    kwargs). A TensorDict result replaces the losses passed to the next hook.
    """
    losses = batch
    for position, (hook, hook_kwargs) in enumerate(self._optimizer_ops):
        result = hook(
            losses, self.clip_grad_norm, self.clip_norm, position, **hook_kwargs
        )
        if isinstance(result, TensorDictBase):
            losses = result
    return losses.detach()
616
+
617
def _post_optim_hook(self):
    """Invoke each ``post_optim`` hook as ``op(**kwargs)`` after an optimization step."""
    for hook, hook_kwargs in self._post_optim_ops:
        hook(**hook_kwargs)
620
+
621
def _pre_epoch_log_hook(self, batch: TensorDictBase) -> None:
    """Execute logging hooks that run BEFORE each epoch of optimization.

    Each ``pre_epoch_log`` hook is called as ``op(batch, **kwargs)``. A
    non-``None`` result is forwarded to :meth:`_log` as keyword metrics.
    Called once per epoch, before the epoch's optimization steps.
    """
    for hook, hook_kwargs in self._pre_epoch_log_ops:
        metrics = hook(batch, **hook_kwargs)
        if metrics is None:
            continue
        self._log(**metrics)
631
+
632
def _pre_epoch_hook(self, batch: TensorDictBase, **kwargs) -> TensorDictBase:
    """Execute regular hooks that run BEFORE each epoch of optimization.

    Each ``pre_epoch`` hook is called as ``op(batch, **op_kwargs)``. As for
    the other batch-transforming hooks, a hook's output replaces the batch
    only when it is a TensorDict. Hooks returning ``None`` (the output
    declared for ``pre_epoch`` in :meth:`register_op`) therefore leave the
    batch untouched. Previously the batch was overwritten with ``None``,
    which made the epoch stop early or fail downstream.

    Args:
        batch: the batch the epoch operates on.
        **kwargs: accepted for backward compatibility; unused.

    Returns:
        The (possibly replaced) batch.
    """
    # ``op_kwargs`` no longer shadows the (unused) ``**kwargs`` parameter.
    for op, op_kwargs in self._pre_epoch_ops:
        out = op(batch, **op_kwargs)
        if isinstance(out, TensorDictBase):
            batch = out
    return batch
641
+
642
def _post_epoch_log_hook(self, batch: TensorDictBase) -> None:
    """Execute logging hooks that run AFTER each epoch of optimization.

    Each ``post_epoch_log`` hook is called as ``op(batch, **kwargs)``. A
    non-``None`` result is forwarded to :meth:`_log` as keyword metrics.
    Called once per epoch, after the epoch's optimization steps.
    """
    for hook, hook_kwargs in self._post_epoch_log_ops:
        metrics = hook(batch, **hook_kwargs)
        if metrics is None:
            continue
        self._log(**metrics)
652
+
653
def _post_epoch_hook(self) -> None:
    """Execute regular hooks that run AFTER each epoch of optimization.

    Each ``post_epoch`` hook is called with its registered keyword arguments
    only (no batch), in registration order.
    """
    for hook, hook_kwargs in self._post_epoch_ops:
        hook(**hook_kwargs)
661
+
662
def _pre_steps_log_hook(self, batch: TensorDictBase) -> None:
    """Execute logging hooks that run BEFORE optimization steps.

    Each ``pre_steps_log`` hook is called as ``op(batch, **kwargs)`` on the
    collected batch (``None`` in asynchronous mode). A non-``None`` result is
    forwarded to :meth:`_log`. Called once per collection, before optimization.
    """
    for hook, hook_kwargs in self._pre_steps_log_ops:
        metrics = hook(batch, **hook_kwargs)
        if metrics is None:
            continue
        self._log(**metrics)
673
+
674
def _post_steps_log_hook(self, batch: TensorDictBase) -> None:
    """Execute logging hooks that run AFTER optimization steps.

    Each ``post_steps_log`` hook is called as ``op(batch, **kwargs)``. A
    non-``None`` result is forwarded to :meth:`_log`. Called once per
    collection, after the optimization steps (e.g. evaluation metrics).
    """
    for hook, hook_kwargs in self._post_steps_log_ops:
        metrics = hook(batch, **hook_kwargs)
        if metrics is None:
            continue
        self._log(**metrics)
685
+
686
def train(self):
    """Run the collection/optimization loop until ``total_frames`` frames are collected.

    Synchronous mode: every batch yielded by the collector goes through the
    ``batch_process`` hooks. The frame counter then grows by the number of
    valid frames (the ``("collector", "mask")`` entry when present,
    ``batch.numel()`` otherwise) multiplied by ``frame_skip``.

    Asynchronous mode: the collector is started, the loop waits until the
    replay buffer has received data, and the frame counter mirrors the
    buffer's ``write_count``. Hooks receive ``None`` as the batch.

    Optimization runs only once ``collector.init_random_frames`` frames have
    been collected. A periodic checkpoint is attempted after each iteration,
    and a forced one when ``total_frames`` is reached. The collector is shut
    down when the loop ends.
    """
    if self.progress_bar:
        self._pbar = tqdm(total=self.total_frames)
        self._pbar_str = {}

    if not self.async_collection:
        source = self.collector
    else:
        self.collector.start()
        # Wait for the first write before entering the training loop.
        while self.collector.getattr_rb("write_count") == 0:
            time.sleep(0.1)
        # Yields None until the replay buffer holds total_frames frames.
        source = self._async_iterator()

    for data in source:
        if self.async_collection:
            data = None
            previous = self.collected_frames
            self.collected_frames = self.collector.getattr_rb("write_count")
            new_frames = self.collected_frames - previous
        else:
            data = self._process_batch_hook(data)
            default_mask = torch.tensor(data.numel())
            valid = data.get(("collector", "mask"), default_mask)
            new_frames = valid.sum().item() * self.frame_skip
            self.collected_frames += new_frames

        # Logging point 1: pre-optimization metrics (rewards, frame counts...).
        self._pre_steps_log_hook(data)

        if self.collected_frames >= self.collector.init_random_frames:
            self.optim_steps(data)
        self._post_steps_hook()

        # Logging point 2: post-optimization metrics (evaluation...).
        self._post_steps_log_hook(data)

        if self.progress_bar:
            self._pbar.update(new_frames)
            self._pbar_description()

        reached_budget = self.collected_frames >= self.total_frames
        if reached_budget:
            self.save_trainer(force_save=True)
            break
        self.save_trainer()

    self.collector.shutdown()
738
+
739
def _async_iterator(self):
    """Yield ``None`` until the replay buffer's ``write_count`` reaches ``total_frames``.

    Used by :meth:`train` in asynchronous mode, where progress is tracked
    through the replay buffer's ``write_count`` rather than through batches.
    Each ``None`` drives one iteration of the training loop. Both
    ``write_count`` and ``total_frames`` are re-read before every yield.
    """
    while True:
        written = self.collector.getattr_rb("write_count")
        if written >= self.total_frames:
            return
        yield None
753
+
754
def __del__(self):
    # Best-effort cleanup at garbage collection: the collector may be missing
    # (partially constructed trainer) or already shut down, so any error is
    # deliberately ignored here.
    try:
        shutdown = self.collector.shutdown
        shutdown()
    except Exception:
        pass
759
+
760
def shutdown(self):
    """Shut the collector down, announcing it on the torchrl logger when ``VERBOSE``."""
    if VERBOSE:
        torchrl_logger.info("shutting down collector")
    collector = self.collector
    collector.shutdown()
764
+
765
def optim_steps(self, batch: TensorDictBase) -> None:
    """Run ``num_epochs`` epochs of optimization on ``batch`` and log the mean losses.

    Each epoch:

    1. runs the pre-epoch logging and regular hooks;
    2. performs up to ``optim_steps_per_batch`` optimization steps (unbounded
       when it is ``None``), stopping early when the ``process_optim_batch``
       hooks raise ``StopIteration`` or return ``None``;
    3. runs the post-epoch logging and regular hooks.

    Each step draws a sub-batch, evaluates the loss module, and applies the
    ``post_loss``, ``process_loss`` and ``optimizer`` hooks, followed by the
    ``post_optim`` and ``post_optim_log`` hooks.

    After all epochs, the running mean of every loss entry over *all*
    performed steps (across epochs) is logged together with the cumulative
    optimization step count. Nothing is logged if no step was performed.

    Fixes over the previous implementation:

    * the mean is no longer reset at each epoch;
    * no ``TypeError`` when the first step hits ``StopIteration``;
    * ``_optim_count`` only counts steps that actually happened.
    """
    average_losses = None
    # Completed steps in this call, across all epochs; drives the running mean.
    n_steps = 0

    self._pre_optim_hook()
    optim_steps_per_batch = self.optim_steps_per_batch

    for _ in range(self.num_epochs):
        # LOGGING POINT 3: Pre-epoch logging (e.g., epoch-specific metrics)
        self._pre_epoch_log_hook(batch)
        # Regular pre-epoch operations (e.g., epoch setup)
        batch_processed = self._pre_epoch_hook(batch)

        if optim_steps_per_batch is None:
            prog = itertools.count()
        else:
            prog = range(optim_steps_per_batch)

        for _ in prog:
            try:
                sub_batch = self._process_optim_batch_hook(batch_processed)
            except StopIteration:
                break
            if sub_batch is None:
                break
            # Count the step only once a sub-batch is actually available.
            self._optim_count += 1

            losses_td = self.loss_module(sub_batch)
            self._post_loss_hook(sub_batch)

            losses_td = self._process_loss_hook(sub_batch, losses_td)

            losses_detached = self._optimizer_hook(losses_td)
            self._post_optim_hook()

            # LOGGING POINT 4: Post-optimization step logging (e.g., gradient norms)
            self._post_optim_log(sub_batch)

            if average_losses is None:
                average_losses = losses_detached
            else:
                for key, item in losses_detached.items():
                    val = average_losses.get(key)
                    average_losses.set(
                        key, val * n_steps / (n_steps + 1) + item / (n_steps + 1)
                    )
            n_steps += 1
            del sub_batch, losses_td, losses_detached

        # LOGGING POINT 5: Post-epoch logging (e.g., epoch completion metrics)
        self._post_epoch_log_hook(batch)
        # Regular post-epoch operations (e.g., epoch cleanup)
        self._post_epoch_hook()

    if average_losses is not None:
        # Main logging point for training metrics: averaged losses and step count.
        self._log(
            optim_steps=self._optim_count,
            **average_losses,
        )
822
+
823
def _log(self, log_pbar=False, **kwargs) -> None:
    """Record metrics and forward them to the logger and/or progress bar.

    Every value is appended to ``self._log_dict[name]``. A value is sent to
    ``self.logger`` (through the method named in ``LOGGER_METHODS``, default
    ``"log_scalar"``, at ``step=collected_frames``) only when more than
    ``_log_interval`` frames elapsed since that metric was last emitted.

    Args:
        log_pbar: if True, scalar metrics are also stored for display in the
            progress bar (tensors are converted with ``.item()``).
        **kwargs: metric name -> value pairs.
    """
    step = self.collected_frames
    for name, value in kwargs.items():
        # The full history is kept regardless of the logging frequency.
        self._log_dict[name].append(value)

        should_emit = (step - self._last_log.get(name, 0)) > self._log_interval
        if should_emit:
            self._last_log[name] = step

        method = LOGGER_METHODS.get(name, "log_scalar")

        if should_emit and self.logger is not None:
            getattr(self.logger, method)(name, value, step=step)

        if method == "log_scalar" and self.progress_bar and log_pbar:
            if isinstance(value, torch.Tensor):
                value = value.item()
            self._pbar_str[name] = value
857
+
858
def _pbar_description(self) -> None:
    """Refresh the progress-bar description from the metrics flagged with ``log_pbar``.

    Metrics are shown sorted by name as ``"name: value"``. The format spec is
    looked up by value type in ``TYPE_DESCR``, defaulting to ``"4.4f"``.
    Does nothing when the progress bar is disabled.
    """
    if not self.progress_bar:
        return
    entries = []
    for name in sorted(self._pbar_str.keys()):
        value = self._pbar_str[name]
        spec = TYPE_DESCR.get(type(value), "4.4f")
        entries.append(f"{name}: {value:{spec}}")
    self._pbar.set_description(", ".join(entries))
873
+
874
def __repr__(self) -> str:
    """Multi-line summary listing the loss module, collector, optimizer and logger."""
    fields = (
        ("loss", self.loss_module),
        ("collector", self.collector),
        ("optimizer", self.optimizer),
        ("logger", self.logger),
    )
    body = "\n".join(indent(f"{label}={value}", " " * 4) for label, value in fields)
    return f"Trainer(\n{body})"
890
+
891
+
892
def _get_list_state_dict(hook_list):
    """Serialize a list of ``(hook, kwargs)`` pairs.

    Each hook is replaced by its ``state_dict()`` when it has one, or by
    ``None`` otherwise; kwargs are kept as-is.
    """
    return [
        (hook.state_dict() if hasattr(hook, "state_dict") else None, hook_kwargs)
        for hook, hook_kwargs in hook_list
    ]
900
+
901
+
902
def _load_list_state_dict(list_state_dict, hook_list):
    """Restore ``hook_list`` in place from the output of ``_get_list_state_dict``.

    Pairs are matched by position (the shorter list bounds the iteration).
    Hooks with a saved state get ``load_state_dict`` called, and every
    matched entry has its kwargs replaced by the saved ones.
    """
    pairs = zip(list_state_dict, hook_list)
    for position, ((saved_state, saved_kwargs), (hook, _old_kwargs)) in enumerate(pairs):
        if saved_state is not None:
            hook.load_state_dict(saved_state)
        hook_list[position] = (hook, saved_kwargs)
909
+
910
+
911
class SelectKeys(TrainerHookBase):
    """Restrict each collected TensorDict batch to a given set of keys.

    Args:
        keys (iterable of strings): keys to be selected in the tensordict.

    Examples:
        >>> trainer = make_trainer()
        >>> key1 = "first key"
        >>> key2 = "second key"
        >>> td = TensorDict(
        ...     {
        ...         key1: torch.randn(3),
        ...         key2: torch.randn(3),
        ...     },
        ...     [],
        ... )
        >>> trainer.register_op("batch_process", SelectKeys([key1]))
        >>> td_out = trainer._process_batch_hook(td)
        >>> assert key1 in td_out.keys()
        >>> assert key2 not in td_out.keys()

    """

    def __init__(self, keys: Sequence[str]):
        # A bare string is iterable too, but ``select(*keys)`` would treat it
        # as a sequence of one-character keys: reject it explicitly.
        if not isinstance(keys, str):
            self.keys = keys
            return
        raise RuntimeError(
            "Expected keys to be an iterable of str, got str instead"
        )

    def __call__(self, batch: TensorDictBase) -> TensorDictBase:
        """Return a TensorDict holding only ``self.keys`` from ``batch``."""
        wanted = self.keys
        return batch.select(*wanted)

    def state_dict(self) -> dict[str, Any]:
        """This hook has no state to save."""
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """No-op: this hook has no state to restore."""

    def register(self, trainer, name="select_keys") -> None:
        """Register as a ``batch_process`` op and as a trainer module named ``name``."""
        trainer.register_op("batch_process", self)
        trainer.register_module(name, self)
954
+
955
+
956
class ReplayBufferTrainer(TrainerHookBase):
    """Replay buffer hook provider.

    Writes collected batches into a replay buffer (``extend``), draws
    optimization sub-batches from it (``sample``) and forwards priority
    updates (``update_priority``).

    Args:
        replay_buffer (TensorDictReplayBuffer): replay buffer to be used.
        batch_size (int, optional): batch size when sampling data from the
            latest collection or from the replay buffer. If none is provided,
            the replay buffer batch-size will be used (preferred option for
            unchanged batch-sizes).
        memmap (bool, optional): stored on the hook as ``memmap``; not read by
            the methods defined here. Default is ``False``.
        device (device, optional): device where the samples must be placed.
            Default to ``None``.
        flatten_tensordicts (bool, optional): if ``True``, the tensordicts will be
            flattened (or equivalently masked with the valid mask obtained from
            the collector) before being passed to the replay buffer. Otherwise,
            no transform will be achieved other than padding (see :obj:`max_dims` arg below).
            Defaults to ``False``.
        max_dims (sequence of int, optional): if :obj:`flatten_tensordicts` is set to False,
            this will be a list of the length of the batch_size of the provided
            tensordicts that represent the maximum size of each. If provided,
            this list of sizes will be used to pad the tensordict and make their shape
            match before they are passed to the replay buffer. If there is no
            maximum value, a -1 value should be provided.
        iterate (bool, optional): if ``True``, the replay buffer will be iterated over
            in a loop. Defaults to ``False`` (call to :meth:`~torchrl.data.ReplayBuffer.sample` will be used).

    Examples:
        >>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
        >>> trainer.register_op("batch_process", rb_trainer.extend)
        >>> trainer.register_op("process_optim_batch", rb_trainer.sample)
        >>> trainer.register_op("post_loss", rb_trainer.update_priority)

    """

    def __init__(
        self,
        replay_buffer: TensorDictReplayBuffer,
        batch_size: int | None = None,
        memmap: bool = False,
        device: DEVICE_TYPING | None = None,
        flatten_tensordicts: bool = False,
        max_dims: Sequence[int] | None = None,
        iterate: bool = False,
    ) -> None:
        self.replay_buffer = replay_buffer
        if hasattr(replay_buffer, "update_tensordict_priority"):
            self._update_priority = self.replay_buffer.update_tensordict_priority
        else:
            # Priorities can only be updated through a TensorDictReplayBuffer.
            if isinstance(replay_buffer.sampler, PrioritizedSampler):
                raise ValueError(
                    "Prioritized sampler not supported for replay buffer trainer if not within a TensorDictReplayBuffer"
                )
            self._update_priority = None
        self.batch_size = batch_size
        self.memmap = memmap
        self.device = device
        self.flatten_tensordicts = flatten_tensordicts
        self.max_dims = max_dims
        self.iterate = iterate
        if iterate:
            self.replay_buffer_iter = iter(self.replay_buffer)

    def extend(self, batch: TensorDictBase) -> TensorDictBase:
        """Write ``batch`` to the replay buffer and return the stored CPU version.

        With ``flatten_tensordicts``, the batch is masked with the collector's
        ``("collector", "mask")`` entry when present. Otherwise it is reshaped
        to a single dimension, after flagging the last step along the last
        batch dimension as truncated (when ``("next", "truncated")`` exists).
        Without ``flatten_tensordicts``, the batch is padded up to ``max_dims``
        when provided.
        """
        if self.flatten_tensordicts:
            if ("collector", "mask") in batch.keys(True):
                batch = batch[batch.get(("collector", "mask"))]
            else:
                if "truncated" in batch["next"]:
                    # Flattening concatenates trajectories, so mark the last
                    # step of each one (last index of the last *batch* dim).
                    # ``[..., -1]`` would instead index the trailing feature
                    # dim of done-like entries (usually shaped
                    # ``(*batch_size, 1)``) and flag every step as truncated.
                    # This modifies the collected batch in place.
                    truncated = batch["next", "truncated"]
                    last_step = (slice(None),) * (batch.ndimension() - 1) + (-1,)
                    truncated[last_step] = True
                batch = batch.reshape(-1)
        elif self.max_dims is not None:
            pads = []
            for d in range(batch.ndimension()):
                max_size = self.max_dims[d]
                pad_value = 0 if max_size == -1 else max_size - batch.batch_size[d]
                pads += [0, pad_value]
            batch = pad(batch, pads)
        batch = batch.cpu()
        self.replay_buffer.extend(batch)
        return batch

    def sample(self, batch: TensorDictBase) -> TensorDictBase:
        """Return a sub-batch from the replay buffer (``batch`` itself is ignored).

        In ``iterate`` mode, the next item of the buffer iterator is returned.
        When the iterator is exhausted it is recreated and ``StopIteration``
        is re-raised, which ends the current epoch in ``Trainer.optim_steps``.
        The sample is moved to ``device`` when one is set.
        """
        if self.iterate:
            try:
                sample = next(self.replay_buffer_iter)
            except StopIteration:
                # reset the replay buffer
                self.replay_buffer_iter = iter(self.replay_buffer)
                raise
        else:
            sample = self.replay_buffer.sample(batch_size=self.batch_size)
        return sample.to(self.device) if self.device is not None else sample

    def update_priority(self, batch: TensorDictBase) -> None:
        """Forward ``batch`` to the buffer's priority update, when supported."""
        if self._update_priority is not None:
            self._update_priority(batch)

    def state_dict(self) -> dict[str, Any]:
        return {
            "replay_buffer": self.replay_buffer.state_dict(),
        }

    def load_state_dict(self, state_dict) -> None:
        self.replay_buffer.load_state_dict(state_dict["replay_buffer"])

    def register(self, trainer: Trainer, name: str = "replay_buffer"):
        """Register ``extend``/``sample``/``update_priority`` ops and the module itself."""
        trainer.register_op("batch_process", self.extend)
        trainer.register_op("process_optim_batch", self.sample)
        trainer.register_op("post_loss", self.update_priority)
        trainer.register_module(name, self)
1071
+
1072
+
1073
class OptimizerHook(TrainerHookBase):
    """Add an optimizer for one or more loss components.

    On each call, the selected loss components are summed and
    back-propagated, and the gradients are clipped. The optimizer is then
    stepped and its gradients zeroed. The gradient norm is stored in the loss
    TensorDict as ``grad_norm_<index>``.

    Args:
        optimizer (optim.Optimizer): An optimizer to apply to the loss_components.
        loss_components (Sequence[str], optional): The keys in the loss TensorDict
            for which the optimizer should be applied to the respective values.
            If omitted, the optimizer is applied to all components whose
            names start with ``"loss"``.

    Examples:
        >>> optimizer_hook = OptimizerHook(optimizer, ["loss_actor"])
        >>> trainer.register_op("optimizer", optimizer_hook)

    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        loss_components: Sequence[str] | None = None,
    ):
        if loss_components is not None and not loss_components:
            raise ValueError(
                "loss_components list cannot be empty. "
                "Set to None to act on all components of the loss."
            )

        self.optimizer = optimizer
        self.loss_components = loss_components
        if self.loss_components is not None:
            self.loss_components = set(self.loss_components)

    def _grad_clip(self, clip_grad_norm: bool, clip_norm: float) -> float:
        """Clip the gradients of all optimizer parameters and return a gradient norm.

        If ``clip_grad_norm`` is true and ``clip_norm`` is set, gradients are
        rescaled with :func:`torch.nn.utils.clip_grad_norm_`, and the total
        norm it returns is reported. Otherwise the total L2 norm of the
        existing gradients is computed. If ``clip_norm`` is set, gradients are
        then clamped element-wise with :func:`torch.nn.utils.clip_grad_value_`.

        Returns:
            The gradient norm as a float; ``0.0`` when no parameter has a gradient.
        """
        params = [p for group in self.optimizer.param_groups for p in group["params"]]

        if clip_grad_norm and clip_norm is not None:
            gn = nn.utils.clip_grad_norm_(params, clip_norm)
        else:
            squared_norms = [p.grad.pow(2).sum() for p in params if p.grad is not None]
            # ``sum([])`` is the int 0, which has no ``.sqrt()``: handle the
            # case where no parameter received a gradient.
            gn = sum(squared_norms).sqrt() if squared_norms else 0.0
            if clip_norm is not None:
                nn.utils.clip_grad_value_(params, clip_norm)

        return float(gn)

    def __call__(
        self,
        losses_td: TensorDictBase,
        clip_grad_norm: bool,
        clip_norm: float,
        index: int,
    ) -> TensorDictBase:
        """Back-propagate the selected losses, clip gradients and step the optimizer.

        Args:
            losses_td: the loss TensorDict; a ``grad_norm_<index>`` entry is added to it.
            clip_grad_norm: clip the gradient norm (True) or the gradient values (False).
            clip_norm: clipping threshold; ``None`` disables clipping.
            index: position of this hook among the trainer's optimizer hooks.

        Returns:
            ``losses_td`` with the added gradient-norm entry.

        Raises:
            ValueError: if no entry of ``losses_td`` matches the selected components.
        """
        if self.loss_components is not None:
            loss_components = [
                item for key, item in losses_td.items() if key in self.loss_components
            ]
        else:
            loss_components = [
                item for key, item in losses_td.items() if key.startswith("loss")
            ]
        if not loss_components:
            # Summing an empty list gives the int 0, which would otherwise
            # fail with an obscure "'int' object has no attribute 'backward'".
            expected = (
                list(self.loss_components)
                if self.loss_components is not None
                else "keys starting with 'loss'"
            )
            raise ValueError(
                f"No loss component to optimize: expected {expected}, "
                f"got keys {list(losses_td.keys())}."
            )
        loss = sum(loss_components)
        loss.backward()

        grad_norm = self._grad_clip(clip_grad_norm, clip_norm)
        losses_td[f"grad_norm_{index}"] = torch.tensor(grad_norm)

        self.optimizer.step()
        self.optimizer.zero_grad()

        return losses_td

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass

    def register(self, trainer, name="optimizer") -> None:
        trainer.register_op("optimizer", self)
        trainer.register_module(name, self)
1151
+
1152
+
1153
class ClearCudaCache(TrainerHookBase):
    """Call ``torch.cuda.empty_cache()`` on every ``interval``-th invocation.

    Examples:
        >>> clear_cuda = ClearCudaCache(100)
        >>> trainer.register_op("pre_optim_steps", clear_cuda)

    """

    def __init__(self, interval: int):
        self.interval = interval
        # Number of times the hook has been called so far.
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1
        if self.count % self.interval:
            return
        torch.cuda.empty_cache()
1170
+
1171
+
1172
class LogTiming(TrainerHookBase):
    """Hook to log timing information collected by timeit context managers.

    On each call, this hook reads the timings accumulated by ``timeit``
    (through ``timeit.todict``) and returns them as a metrics dict for the
    trainer to log. It is handy for profiling the parts of the training loop.

    Args:
        prefix (str, optional): Prefix to add to timing metric names.
            Default is "time".
        percall (bool, optional): If True, log average time per call.
            If False, log total time. Default is True.
        erase (bool, optional): If True, reset timing data after each log.
            Default is False.

    Examples:
        >>> # Log timing data after each optimization step
        >>> log_timing = LogTiming(prefix="time", percall=True)
        >>> trainer.register_op("post_optim_log", log_timing)

        >>> # Log timing data after each batch collection
        >>> log_timing = LogTiming(prefix="time", erase=True)
        >>> trainer.register_op("post_steps_log", log_timing)

    Note:
        This hook reports data gathered with the ``timeit`` context manager.
        Hooks registered through ``register_op`` are automatically wrapped
        in such a timer.
    """

    def __init__(
        self,
        prefix: str = "time",
        percall: bool = True,
        erase: bool = False,
    ):
        self.prefix, self.percall, self.erase = prefix, percall, erase

    def __call__(self, batch: TensorDictBase | None = None) -> dict:
        """Return the current timings as ``{metric_name: value}``.

        Args:
            batch: ignored; present to match the logging-hook signature.
        """
        timings = timeit.todict(percall=self.percall, prefix=self.prefix)
        if self.erase:
            timeit.erase()
        return timings

    def state_dict(self) -> dict[str, Any]:
        """Return the hook configuration for checkpointing."""
        return {attr: getattr(self, attr) for attr in ("prefix", "percall", "erase")}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the configuration, using the constructor defaults for missing keys."""
        defaults = {"prefix": "time", "percall": True, "erase": False}
        for attr, default in defaults.items():
            setattr(self, attr, state_dict.get(attr, default))

    def register(self, trainer: Trainer, name: str | None = None):
        """Register as module ``name`` (default ``"log_timing"``) and as a ``post_steps_log`` op."""
        trainer.register_module("log_timing" if name is None else name, self)
        trainer.register_op("post_steps_log", self)
1247
+
1248
+
1249
class LogScalar(TrainerHookBase):
    """Generic scalar logger hook for any tensor values in the batch.

    This hook can log any scalar values from the collected batch data, including
    rewards, action norms, done states, and any other metrics. When the batch
    holds a ``("collector", "mask")`` entry, only the masked-in values are used.
    The configured reduction is logged, optionally together with the standard
    deviation of the values.

    Args:
        key (NestedKey): the key where to find the value in the input batch.
            Can be a string for simple keys or a tuple for nested keys.
            Default is `torchrl.trainers.trainers.REWARD_KEY` (= `("next", "reward")`).
        logname (str, optional): name of the metric to be logged. If None, will use
            the key as the log name. Default is None.
        log_pbar (bool, optional): if ``True``, the value will be logged on
            the progression bar. Default is ``False``.
        include_std (bool, optional): if ``True``, also log the standard deviation
            of the values under ``f"{logname}_std"`` (only when more than one
            value is available). Non floating-point values (e.g. boolean done
            flags or integer counts) are cast to float for this computation.
            Default is ``True``.
        reduction (str, optional): reduction method to apply. Can be "mean", "sum",
            "min", "max". Default is "mean".

    Raises:
        ValueError: if ``reduction`` is not one of the supported methods.

    Examples:
        >>> # Log training rewards
        >>> log_reward = LogScalar(("next", "reward"), "r_training", log_pbar=True)
        >>> trainer.register_op("pre_steps_log", log_reward)

        >>> # Log action norms
        >>> log_action_norm = LogScalar("action", "action_norm", include_std=True)
        >>> trainer.register_op("pre_steps_log", log_action_norm)

        >>> # Log done states (as percentage)
        >>> log_done = LogScalar(("next", "done"), "done_percentage", reduction="mean")
        >>> trainer.register_op("pre_steps_log", log_done)

    """

    def __init__(
        self,
        key: NestedKey = REWARD_KEY,
        logname: str | None = None,
        log_pbar: bool = False,
        include_std: bool = True,
        reduction: str = "mean",
    ):
        # Validate reduction method before storing any configuration.
        if reduction not in ("mean", "sum", "min", "max"):
            raise ValueError(
                f"reduction must be one of ['mean', 'sum', 'min', 'max'], got {reduction}"
            )
        self.key = key
        self.logname = logname if logname is not None else str(key)
        self.log_pbar = log_pbar
        self.include_std = include_std
        self.reduction = reduction

    def _apply_reduction(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the specified reduction to the tensor."""
        if self.reduction == "mean":
            # mean() is undefined for integer/bool tensors, hence the cast.
            return tensor.float().mean()
        if self.reduction == "sum":
            return tensor.sum()
        if self.reduction == "min":
            return tensor.min()
        if self.reduction == "max":
            return tensor.max()
        raise ValueError(f"Unknown reduction: {self.reduction}")

    def __call__(self, batch: TensorDictBase) -> dict:
        """Reduce the values stored under ``self.key`` into loggable scalars.

        Returns:
            A dict with the reduced value under ``self.logname``, the
            ``"log_pbar"`` flag and, if requested and more than one value is
            available, the standard deviation under ``f"{self.logname}_std"``.
        """
        tensor = batch.get(self.key)

        # Apply mask if available
        if ("collector", "mask") in batch.keys(True):
            mask = batch.get(("collector", "mask"))
            tensor = tensor[mask]

        result = {
            self.logname: self._apply_reduction(tensor).item(),
            "log_pbar": self.log_pbar,
        }

        if self.include_std and tensor.numel() > 1:
            # std() only supports floating-point and complex dtypes; cast other
            # dtypes (e.g. boolean "done" flags) so they can be logged too.
            if not (tensor.is_floating_point() or tensor.is_complex()):
                tensor = tensor.float()
            result[f"{self.logname}_std"] = tensor.std().item()

        return result

    def register(self, trainer: Trainer, name: str | None = None):
        if name is None:
            name = f"log_{self.logname}"
        trainer.register_op("pre_steps_log", self)
        trainer.register_module(name, self)
1347
+
1348
+
1349
class RewardNormalizer(TrainerHookBase):
    """Reward normalizer hook.

    Maintains exponentially-decayed running statistics (sum, sum of squares
    and element count) of the rewards seen in collected batches, and uses the
    derived mean and standard deviation to standardize the rewards of the
    batches fed to the optimizer.

    Args:
        decay (:obj:`float`, optional): exponential moving average decay parameter.
            Default is 0.999
        scale (:obj:`float`, optional): the scale used to multiply the reward once
            normalized. Defaults to 1.0.
        eps (:obj:`float`, optional): the epsilon jitter used to prevent numerical
            underflow. Defaults to ``torch.finfo(DEFAULT_DTYPE).eps``
            where ``DEFAULT_DTYPE=torch.get_default_dtype()``.
        log_pbar (bool, optional): accepted for API compatibility; unused.
        reward_key (str or tuple, optional): the key where to find the reward
            in the input batch. Defaults to ``("next", "reward")``

    Examples:
        >>> reward_normalizer = RewardNormalizer()
        >>> trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
        >>> trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)

    """

    def __init__(
        self,
        decay: float = 0.999,
        scale: float = 1.0,
        eps: float | None = None,
        log_pbar: bool = False,
        reward_key=None,
    ):
        self._normalize_has_been_called = False
        self._update_has_been_called = False
        self._reward_stats = OrderedDict()
        self._reward_stats["decay"] = decay
        self.scale = scale
        if eps is None:
            eps = torch.finfo(torch.get_default_dtype()).eps
        self.eps = eps
        if reward_key is None:
            reward_key = REWARD_KEY
        self.reward_key = reward_key

    @torch.no_grad()
    def update_reward_stats(self, batch: TensorDictBase) -> None:
        """Fold the (masked) rewards of ``batch`` into the running statistics."""
        reward = batch.get(self.reward_key)
        if ("collector", "mask") in batch.keys(True):
            reward = reward[batch.get(("collector", "mask"))]
        # NOTE: two consecutive updates without a normalize_reward call are
        # deliberately tolerated: the trainer can collect data without
        # running optimization steps.
        stats = self._reward_stats
        decay = stats.get("decay", 0.999)
        reward_sum = stats["sum"] = decay * stats.get("sum", 0.0) + reward.sum()
        reward_ssq = stats["ssq"] = (
            decay * stats.get("ssq", 0.0) + reward.pow(2).sum()
        )
        count = stats["count"] = decay * stats.get("count", 0.0) + reward.numel()

        stats["mean"] = reward_sum / count
        if count > 1:
            # Unbiased variance estimate from the decayed moments.
            var = (reward_ssq - reward_sum.pow(2) / count) / (count - 1)
        else:
            var = torch.zeros_like(reward_sum)
        stats["var"] = var

        stats["std"] = var.clamp_min(self.eps).sqrt()
        self._update_has_been_called = True

    def normalize_reward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Return a copy of ``tensordict`` with standardized, scaled rewards.

        Raises:
            RuntimeError: if no statistics have been gathered yet.
        """
        if "mean" not in self._reward_stats:
            raise RuntimeError(
                "normalize_reward was called before any reward statistics were "
                "gathered. Make sure update_reward_stats is registered (e.g. under "
                "'batch_process') and runs before normalize_reward."
            )
        tensordict = tensordict.to_tensordict()  # make sure it is not a SubTensorDict
        reward = tensordict.get(self.reward_key)

        mean = self._reward_stats["mean"].to(reward.device)
        std = self._reward_stats["std"].to(reward.device)
        tensordict.set(self.reward_key, (reward - mean) / std * self.scale)
        self._normalize_has_been_called = True
        return tensordict

    def state_dict(self) -> dict[str, Any]:
        return {
            "_reward_stats": deepcopy(self._reward_stats),
            "scale": self.scale,
            "eps": self.eps,
            "_normalize_has_been_called": self._normalize_has_been_called,
            "_update_has_been_called": self._update_has_been_called,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Older checkpoints lack some keys (e.g. "eps"); those attributes keep
        # their current values.
        for key, value in state_dict.items():
            setattr(self, key, value)

    def register(self, trainer: Trainer, name: str = "reward_normalizer"):
        trainer.register_op("batch_process", self.update_reward_stats)
        trainer.register_op("process_optim_batch", self.normalize_reward)
        trainer.register_module(name, self)
1453
+
1454
+
1455
def mask_batch(batch: TensorDictBase) -> TensorDictBase:
    """Keep only the valid steps of a padded batch.

    When ``batch`` carries a ``("collector", "mask")`` entry, the batch indexed
    by that mask is returned. Otherwise ``batch`` is returned unchanged.

    Args:
        batch: the collected tensordict, possibly holding padded trajectories.

    Examples:
        >>> trainer = mocking_trainer()
        >>> trainer.register_op("batch_process", mask_batch)

    """
    valid_key = ("collector", "mask")
    if valid_key not in batch.keys(True):
        return batch
    return batch[batch.get(valid_key)]
1474
+
1475
+
1476
class BatchSubSampler(TrainerHookBase):
    """Data subsampler for online RL sota-implementations.

    This class subsamples a part of a whole batch of data just collected from the
    environment.

    Args:
        batch_size (int): sub-batch size to collect. The provided batch size
            must be equal to the total number of items in the output tensordict,
            which will have size [batch_size // sub_traj_len, sub_traj_len].
        sub_traj_len (int, optional): length of the trajectories that
            sub-samples must have in online settings. Any non-positive value
            means the full length of the trajectory is used. Default is ``0``.
        min_sub_traj_len (int, optional): minimum value of :obj:`sub_traj_len`, in
            case some elements of the batch contain few steps.
            Default is ``0`` (i.e. no minimum value).

    Examples:
        >>> td = TensorDict(
        ...     {
        ...         key1: torch.stack([torch.arange(0, 10), torch.arange(10, 20)], 0),
        ...         key2: torch.stack([torch.arange(0, 10), torch.arange(10, 20)], 0),
        ...     },
        ...     [2, 10],
        ... )
        >>> trainer.register_op(
        ...     "process_optim_batch",
        ...     BatchSubSampler(batch_size=batch_size, sub_traj_len=sub_traj_len),
        ... )
        >>> td_out = trainer._process_optim_batch_hook(td)
        >>> assert td_out.shape == torch.Size([batch_size // sub_traj_len, sub_traj_len])

    """

    def __init__(
        self, batch_size: int, sub_traj_len: int = 0, min_sub_traj_len: int = 0
    ) -> None:
        self.batch_size = batch_size
        self.sub_traj_len = sub_traj_len
        self.min_sub_traj_len = min_sub_traj_len

    def __call__(self, batch: TensorDictBase) -> TensorDictBase:
        """Sub-sampled part of a batch randomly.

        If the batch has one dimension, a random subsample (without
        replacement) of length self.batch_size will be returned. If the batch
        has two or more dimensions, it is assumed that the first dimension
        represents the batch, and the second the time. In that case
        ``self.batch_size // sub_traj_len`` trajectories are drawn (with
        replacement) among those holding at least ``sub_traj_len`` valid
        steps, and a window of ``sub_traj_len`` consecutive steps is taken
        from each.

        Raises:
            ValueError: if the effective ``sub_traj_len`` is outside
                ``[1, batch.shape[1]]``.
            RuntimeError: if the resulting batch size is zero, if no
                trajectory is long enough, or if invalid steps were sampled.
        """
        if batch.ndimension() == 1:
            return batch[torch.randperm(batch.shape[0])[: self.batch_size]]

        mask_key = ("collector", "mask")
        has_mask = mask_key in batch.keys(True)
        max_len = batch.shape[1]
        sub_traj_len = self.sub_traj_len if self.sub_traj_len > 0 else max_len
        if has_mask:
            # if a valid mask is present, it's important to sample only
            # valid steps
            traj_len = batch.get(mask_key).sum(-1)
            sub_traj_len = max(
                self.min_sub_traj_len,
                min(sub_traj_len, traj_len.min().int().item()),
            )
        else:
            traj_len = torch.full(
                (batch.shape[0],), max_len, device=batch.device, dtype=torch.long
            )
        # Validate up-front: a zero length would otherwise raise
        # ZeroDivisionError below, and a too-long one an obscure sampling error.
        if not 1 <= sub_traj_len <= max_len:
            raise ValueError(
                f"sub_traj_len={sub_traj_len} is not allowed. Accepted values "
                f"are in the range [1, {batch.shape[1]}]."
            )

        len_mask = traj_len >= sub_traj_len
        valid_trajectories = torch.arange(batch.shape[0], device=batch.device)[len_mask]

        batch_size = self.batch_size // sub_traj_len
        if batch_size == 0:
            raise RuntimeError(
                "Resulting batch size is zero. The batch size given to "
                "BatchSubSampler must be equal to the total number of elements "
                "that will result in a batch provided to the loss function."
            )
        if valid_trajectories.numel() == 0:
            raise RuntimeError(
                f"No trajectory in the batch has at least {sub_traj_len} valid "
                "steps; cannot sub-sample."
            )
        traj_idx = valid_trajectories[
            torch.randint(
                valid_trajectories.numel(), (batch_size,), device=batch.device
            )
        ]

        if sub_traj_len < max_len:
            _traj_len = traj_len[traj_idx]
            # Valid start offsets are 0, ..., traj_len - sub_traj_len (inclusive),
            # i.e. traj_len - sub_traj_len + 1 possible starts per trajectory.
            max_start = _traj_len - sub_traj_len
            seq_idx = (
                torch.rand_like(_traj_len, dtype=torch.float) * (max_start + 1)
            ).long()
            # Guard against float rounding landing one past the last valid start.
            seq_idx = torch.minimum(seq_idx, max_start)
            seq_idx = seq_idx.unsqueeze(-1).expand(-1, sub_traj_len)
        else:
            seq_idx = torch.zeros(
                batch_size, sub_traj_len, device=batch.device, dtype=torch.long
            )

        seq_idx = seq_idx + torch.arange(sub_traj_len, device=seq_idx.device)
        td = batch[traj_idx].clone()
        td = td.apply(
            lambda t: t.gather(
                dim=1,
                index=expand_right(seq_idx, (batch_size, sub_traj_len, *t.shape[2:])),
            ),
            batch_size=(batch_size, sub_traj_len),
        )
        if has_mask and not td.get(mask_key).all():
            raise RuntimeError("Sampled invalid steps")
        return td

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass

    def register(self, trainer: Trainer, name: str = "batch_subsampler"):
        trainer.register_op(
            "process_optim_batch",
            self,
        )
        trainer.register_module(name, self)
1604
+
1605
+
1606
class LogValidationReward(TrainerHookBase):
    """Recorder hook for :class:`~torchrl.trainers.Trainer`.

    Args:
        record_interval (int): total number of optimization steps
            between two calls to the recorder for testing.
        record_frames (int): number of frames to be recorded during
            testing.
        frame_skip (int): frame_skip used in the environment. It is
            important to let the trainer know the number of frames skipped at
            each iteration, otherwise the frame count can be underestimated.
            For logging, this parameter is important to normalize the reward.
            Finally, to compare different runs with different frame_skip,
            one must normalize the frame count and rewards. Defaults to ``1``.
        policy_exploration (ProbabilisticTDModule): a policy
            instance used for

            (1) updating the exploration noise schedule;

            (2) testing the policy on the recorder.

            Given that this instance is supposed to both explore and render
            the performance of the policy, it should be possible to turn off
            the explorative behavior by calling the
            `set_exploration_type(ExplorationType.DETERMINISTIC)` context manager.
        environment (EnvBase): An environment instance to be used
            for testing.
        exploration_type (ExplorationType, optional): exploration mode used for
            the evaluation rollouts. Defaults to ``ExplorationType.RANDOM``.
            Set to ``ExplorationType.DETERMINISTIC`` to evaluate the policy
            without exploration.
        log_keys (sequence of str or tuples or str, optional): keys to read in the tensordict
            for logging. Defaults to ``[("next", "reward")]``.
        out_keys (Dict[str, str], optional): a dictionary mapping the ``log_keys``
            to their name in the logs. Defaults to ``{("next", "reward"): "r_evaluation"}``.
        suffix (str, optional): suffix of the video to be recorded.
        log_pbar (bool, optional): if ``True``, the reward value will be logged on
            the progression bar. Default is `False`.
        recorder (EnvBase, optional): deprecated alias of ``environment``.

    Raises:
        ValueError: if both ``environment`` and ``recorder`` are passed.

    """

    ENV_DEPREC = (
        "the environment should be passed under the 'environment' key"
        " and not the 'recorder' key."
    )

    def __init__(
        self,
        *,
        record_interval: int,
        record_frames: int,
        frame_skip: int = 1,
        policy_exploration: TensorDictModule,
        environment: EnvBase = None,
        exploration_type: ExplorationType = ExplorationType.RANDOM,
        log_keys: list[str | tuple[str]] | None = None,
        out_keys: dict[str | tuple[str], str] | None = None,
        suffix: str | None = None,
        log_pbar: bool = False,
        recorder: EnvBase = None,
    ) -> None:
        if environment is None and recorder is not None:
            warnings.warn(self.ENV_DEPREC)
            environment = recorder
        elif environment is not None and recorder is not None:
            raise ValueError("environment and recorder conflict.")
        self.policy_exploration = policy_exploration
        self.environment = environment
        self.record_frames = record_frames
        self.frame_skip = frame_skip
        self._count = 0
        self.record_interval = record_interval
        self.exploration_type = exploration_type
        if log_keys is None:
            log_keys = [("next", "reward")]
        if out_keys is None:
            # Unlisted keys are logged under their own name.
            out_keys = KeyDependentDefaultDict(lambda x: x)
            out_keys[("next", "reward")] = "r_evaluation"
        self.log_keys = log_keys
        self.out_keys = out_keys
        self.suffix = suffix
        self.log_pbar = log_pbar

    @torch.inference_mode()
    def __call__(self, batch: TensorDictBase) -> dict:
        """Run an evaluation rollout every ``record_interval`` calls.

        Returns:
            ``None`` on calls where no evaluation happens, otherwise a dict of
            logged values plus the ``"log_pbar"`` flag.
        """
        out = None
        if self._count % self.record_interval == 0:
            with set_exploration_type(self.exploration_type):
                policy_is_module = isinstance(self.policy_exploration, torch.nn.Module)
                if policy_is_module:
                    self.policy_exploration.eval()
                self.environment.eval()
                try:
                    td_record = self.environment.rollout(
                        policy=self.policy_exploration,
                        max_steps=self.record_frames,
                        auto_reset=True,
                        auto_cast_to_device=True,
                        break_when_any_done=False,
                    ).clone()
                    td_record = split_trajectories(td_record)
                finally:
                    # Restore training mode even if the rollout raised, so a
                    # failed evaluation does not leave the policy in eval mode.
                    if policy_is_module:
                        self.policy_exploration.train()
                    self.environment.train()
                self.environment.transform.dump(suffix=self.suffix)

                out = {}
                for key in self.log_keys:
                    value = td_record.get(key).float()
                    if key == ("next", "reward"):
                        mask = td_record["mask"]
                        mean_value = value[mask].mean() / self.frame_skip
                        total_value = value.sum(dim=td_record.ndim - 1).mean()
                        out[self.out_keys[key]] = mean_value
                        out["total_" + self.out_keys[key]] = total_value
                        continue
                    out[self.out_keys[key]] = value
                out["log_pbar"] = self.log_pbar
        self._count += 1
        self.environment.close()
        return out

    def state_dict(self) -> dict:
        return {
            "_count": self._count,
            "recorder_state_dict": self.environment.state_dict(),
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self._count = state_dict["_count"]
        self.environment.load_state_dict(state_dict["recorder_state_dict"])

    def register(self, trainer: Trainer, name: str = "recorder"):
        trainer.register_module(name, self)
        trainer.register_op(
            "post_steps_log",
            self,
        )
1741
+
1742
+
1743
def _resolve_module(trainer: Trainer, path: str):
    """Follow a dot-separated attribute path starting from ``trainer``.

    Args:
        trainer (Trainer): the object the lookup starts from.
        path (str): attribute names separated by ``"."``, e.g.
            ``"loss_module.actor_network"``.

    Returns:
        The object reached at the end of the path.

    Raises:
        AttributeError: if any attribute along the path is missing.

    Examples:
        >>> module = _resolve_module(trainer, "loss_module.actor_network")
        >>> module = _resolve_module(trainer, "collector.policy")
    """
    current = trainer
    pending = path.split(".")
    while pending:
        current = getattr(current, pending.pop(0))
    return current
1764
+
1765
+
1766
class UpdateWeights(TrainerHookBase):
    """Periodically push the trained weights to a data collector.

    Every ``update_weights_interval`` calls, the collector's
    ``update_policy_weights_`` is invoked. If ``weight_update_map`` was given,
    it is called with a per-destination ``weights_dict``. Otherwise it is called
    with the output of ``policy_weights_getter``, or with no argument when that
    getter is absent or returns ``None``.

    Args:
        collector (BaseCollector): the data collector whose weights are synced.
        update_weights_interval (int): number of calls between two syncs.
        policy_weights_getter (Callable, optional): legacy callable returning
            the weights to send. Ignored when ``weight_update_map`` is given.
        weight_update_map (dict[str, str], optional): maps destination paths
            (keys of the collector's weight sync schemes) to source attribute
            paths on the trainer.
            Example: ``{"policy": "loss_module.actor_network", "replay_buffer.transforms[0]": "loss_module.critic_network"}``.
        trainer (Trainer, optional): the trainer used to resolve the source
            paths of ``weight_update_map``; required with that argument.

    Raises:
        ValueError: if ``weight_update_map`` is provided without ``trainer``.

    Examples:
        >>> # Legacy usage with policy_weights_getter
        >>> update_weights = UpdateWeights(
        ...     trainer.collector, T,
        ...     policy_weights_getter=lambda: TensorDict.from_module(policy)
        ... )
        >>> trainer.register_op("post_steps", update_weights)

        >>> # New usage with weight_update_map
        >>> update_weights = UpdateWeights(
        ...     trainer.collector, T,
        ...     weight_update_map={
        ...         "policy": "loss_module.actor_network",
        ...         "replay_buffer.transforms[0]": "loss_module.critic_network"
        ...     },
        ...     trainer=trainer
        ... )
        >>> trainer.register_op("post_steps", update_weights)

    """

    def __init__(
        self,
        collector: BaseCollector,
        update_weights_interval: int,
        policy_weights_getter: Callable[[Any], Any] | None = None,
        weight_update_map: dict[str, str] | None = None,
        trainer: Trainer | None = None,
    ):
        self.collector = collector
        self.update_weights_interval = update_weights_interval
        self.counter = 0
        self.policy_weights_getter = policy_weights_getter
        self.weight_update_map = weight_update_map
        self.trainer = trainer

        missing_trainer = trainer is None
        if weight_update_map is not None and missing_trainer:
            raise ValueError("trainer must be provided when using weight_update_map")

    def __call__(self):
        self.counter += 1
        if self.counter % self.update_weights_interval != 0:
            return
        if self.weight_update_map is not None:
            self._update_with_map()
            return
        getter = self.policy_weights_getter
        weights = None if getter is None else getter()
        if weights is None:
            self.collector.update_policy_weights_()
        else:
            self.collector.update_policy_weights_(weights)

    def _update_with_map(self):
        """Resolve every mapped source and send all weights in one call."""
        from torchrl.weight_update.weight_sync_schemes import WeightStrategy

        collected = {}
        for destination, source_path in self.weight_update_map.items():
            source_module = _resolve_module(self.trainer, source_path)
            schemes = getattr(self.collector, "_weight_sync_schemes", None)
            if schemes and destination in schemes:
                # Extract in the format the destination's scheme expects.
                extractor = WeightStrategy(extract_as=schemes[destination].strategy_str)
                collected[destination] = extractor.extract_weights(source_module)
            else:
                # No scheme registered for this destination: default to TensorDict.
                collected[destination] = TensorDict.from_module(source_module)

        self.collector.update_policy_weights_(weights_dict=collected)

    def register(self, trainer: Trainer, name: str = "update_weights"):
        trainer.register_module(name, self)
        trainer.register_op(
            "post_steps",
            self,
        )

    def state_dict(self) -> dict:
        return dict()

    def load_state_dict(self, state_dict) -> None:
        return None
1886
+
1887
+
1888
class CountFramesLog(TrainerHookBase):
    """A frame counter hook.

    Accumulates, across calls, the number of environment frames seen. When a
    batch holds a ``("collector", "mask")`` entry only the masked-in steps are
    counted; each counted step accounts for ``frame_skip`` frames.

    Args:
        frame_skip (int): frame skip of the environment. This argument is
            important to keep track of the total number of frames, not the
            apparent one.
        log_pbar (bool, optional): if ``True``, the frame count will be logged on
            the progression bar. Default is `False`.

    Examples:
        >>> count_frames = CountFramesLog(frame_skip=frame_skip)
        >>> trainer.register_op("pre_steps_log", count_frames)


    """

    # Class-level default so every instance (including subclasses that do not
    # call ``super().__init__``) starts counting from zero; the first ``+=``
    # creates a per-instance value.
    frame_count: int = 0

    def __init__(self, frame_skip: int, log_pbar: bool = False):
        self.frame_count = 0
        self.frame_skip = frame_skip
        self.log_pbar = log_pbar

    def __call__(self, batch: TensorDictBase) -> dict:
        if ("collector", "mask") in batch.keys(True):
            current_frames = (
                batch.get(("collector", "mask")).sum().item() * self.frame_skip
            )
        else:
            current_frames = batch.numel() * self.frame_skip
        self.frame_count += current_frames
        return {"n_frames": self.frame_count, "log_pbar": self.log_pbar}

    def register(self, trainer: Trainer, name: str = "count_frames_log"):
        trainer.register_module(name, self)
        trainer.register_op(
            "pre_steps_log",
            self,
        )

    def state_dict(self) -> dict:
        return {"frame_count": self.frame_count}

    def load_state_dict(self, state_dict) -> None:
        self.frame_count = state_dict["frame_count"]
1936
+
1937
+
1938
def _check_input_output_typehint(
    func: Callable, input: type | list[type], output: type
):
    """Validate ``func``'s input/output type hints against expectations.

    Placeholder: no check is performed yet and ``None`` is always returned.
    """
    return None
1943
+
1944
+
1945
def flatten_dict(d):
    """Flattens a dictionary with sub-dictionaries accessed through point-separated (:obj:`"var1.var2"`) fields.

    Nested keys are joined with ``"."`` (``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``). Keys taking part in a join are converted to strings first,
    so non-string keys (ints, tuples, ...) inside nested dicts are supported.
    Top-level keys whose value is not a dict are kept as-is. Empty nested
    dicts produce no entries.
    """
    out = {}
    for key, item in d.items():
        if isinstance(item, dict):
            for sub_key, sub_item in flatten_dict(item).items():
                # f-string formatting handles non-str keys, which ".".join rejects.
                out[f"{key}.{sub_key}"] = sub_item
        else:
            out[key] = item
    return out
1956
+
1957
+
1958
class TargetNetUpdaterHook(TrainerHookBase):
    """A hook for target parameters update.

    Each call invokes ``target_params_updater.step()`` and returns its input
    unchanged.

    Args:
        target_params_updater (TargetNetUpdater): the updater to step.

    Raises:
        ValueError: if ``target_params_updater`` is not a ``TargetNetUpdater``.

    Examples:
        >>> # define a loss module
        >>> loss_module = SACLoss(actor_network, qvalue_network)
        >>> # define a target network updater
        >>> target_net_updater = SoftUpdate(loss_module)
        >>> # define a target network updater hook
        >>> target_net_updater_hook = TargetNetUpdaterHook(target_net_updater)
        >>> # register the target network updater hook
        >>> trainer.register_op("post_optim", target_net_updater_hook)

    Note:
        :meth:`register` attaches the hook to the ``"post_steps"`` op. Call
        ``trainer.register_op`` directly (as in the example) to use another op.
    """

    def __init__(self, target_params_updater: TargetNetUpdater):
        if not isinstance(target_params_updater, TargetNetUpdater):
            raise ValueError(
                f"Expected a target network updater, got {type(target_params_updater)=}"
            )
        self.target_params_updater = target_params_updater

    def __call__(self, tensordict: TensorCollection | None = None):
        self.target_params_updater.step()
        return tensordict

    def register(self, trainer: Trainer, name: str = "target_net_updater"):
        """Register the hook under the ``"post_steps"`` op.

        Args:
            trainer (Trainer): the trainer to register with.
            name (str, optional): accepted for consistency with the other
                hooks' ``register`` signature; currently unused.
        """
        trainer.register_op("post_steps", self)
1985
+
1986
+
1987
+ class UTDRHook(TrainerHookBase):
1988
+ """Hook for logging Update-to-Data (UTD) ratio during async collection.
1989
+
1990
+ The UTD ratio measures how many optimization steps are performed per
1991
+ collected data sample, providing insight into training efficiency during
1992
+ asynchronous data collection. This metric is particularly useful for
1993
+ off-policy algorithms where data collection and training happen concurrently.
1994
+
1995
+ The UTD ratio is calculated as: (batch_size * update_count) / write_count
1996
+ where:
1997
+ - batch_size: Size of batches sampled from replay buffer
1998
+ - update_count: Total number of optimization steps performed
1999
+ - write_count: Total number of samples written to replay buffer
2000
+
2001
+ Args:
2002
+ trainer (Trainer): The trainer instance to monitor for UTD calculation.
2003
+ Must have async_collection=True for meaningful results.
2004
+
2005
+ Note:
2006
+ This hook is only meaningful when async_collection is enabled, as it
2007
+ relies on the replay buffer's write_count to track data collection progress.
2008
+ """
2009
+
2010
def __init__(self, trainer: Trainer):
    """Store the trainer whose replay buffer / optimizer counts are monitored."""
    self.trainer = trainer
2012
+
2013
def __call__(self, batch: TensorDictBase | None = None) -> dict:
    """Compute the current update-to-data ratio.

    Counts are read from ``trainer.replay_buffer`` when present, otherwise
    through ``trainer.collector.getattr_rb``.

    Args:
        batch: unused; accepted to match the hook call signature.

    Returns:
        An empty dict when nothing has been written yet, otherwise a dict with
        ``"utd_ratio"``, ``"write_count"``, ``"update_count"`` and
        ``"log_pbar"`` entries.
    """
    if (
        hasattr(self.trainer, "replay_buffer")
        and self.trainer.replay_buffer is not None
    ):
        write_count = self.trainer.replay_buffer.write_count
        batch_size = self.trainer.replay_buffer.batch_size
    else:
        write_count = self.trainer.collector.getattr_rb("write_count")
        batch_size = self.trainer.collector.getattr_rb("batch_size")
    if not write_count:
        return {}
    if batch_size is None:
        # Always fall back to 1; only the warning depends on rl_warnings(),
        # otherwise ``None * update_count`` below would raise a TypeError.
        if rl_warnings():
            warnings.warn("Batch size is not set. Using 1.")
        batch_size = 1
    update_count = self.trainer._optim_count
    utd_ratio = batch_size * update_count / write_count
    return {
        "utd_ratio": utd_ratio,
        "write_count": write_count,
        "update_count": update_count,
        "log_pbar": False,
    }
2036
+
2037
def register(self, trainer: Trainer, name: str = "utdr_hook"):
    """Attach the UTD ratio hook to ``trainer``.

    The hook is added to the ``"pre_steps_log"`` op and stored as a trainer
    module under ``name``.

    Args:
        trainer (Trainer): The trainer to register with.
        name (str): Name to use when registering the hook module.
    """
    log_op = "pre_steps_log"
    trainer.register_op(log_op, self)
    trainer.register_module(name, self)
2046
+
2047
def state_dict(self) -> dict[str, Any]:
    """Return an empty checkpoint: this method stores no state."""
    return dict()
2050
+
2051
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
2052
+ """Load state from dictionary."""