torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/trainers/algorithms/ppo.py
@@ -0,0 +1,373 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import pathlib
+import warnings
+
+from collections.abc import Callable
+
+from functools import partial
+
+from tensordict import TensorDict, TensorDictBase
+from torch import optim
+
+from torchrl.collectors import BaseCollector
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.value.advantages import GAE
+from torchrl.record.loggers import Logger
+from torchrl.trainers.trainers import (
+    LogScalar,
+    ReplayBufferTrainer,
+    Trainer,
+    UpdateWeights,
+)
+
+try:
+    import tqdm  # noqa: F401  # presumed optional dependency (inferred from the _has_tqdm flag)
+
+    _has_tqdm = True
+except ImportError:
+    _has_tqdm = False
+
+try:
+    import torchsnapshot  # noqa: F401  # presumed optional dependency (inferred from the _has_ts flag)
+
+    _has_ts = True
+except ImportError:
+    _has_ts = False
+
+
+class PPOTrainer(Trainer):
+    """PPO (Proximal Policy Optimization) trainer implementation.
+
+    .. warning::
+        This is an experimental/prototype feature. The API may change in future versions.
+        Please report any issues or feedback to help improve this implementation.
+
+    This trainer implements the PPO algorithm for training reinforcement learning agents.
+    It extends the base Trainer class with PPO-specific functionality including
+    policy optimization, value function learning, and entropy regularization.
+
+    PPO typically uses multiple epochs of optimization on the same batch of data.
+    This trainer defaults to 4 epochs, which is a common choice for PPO implementations.
+
+    The trainer includes comprehensive logging capabilities for monitoring training progress:
+    - Training rewards (mean, std, max, total)
+    - Action statistics (norms)
+    - Episode completion rates
+    - Observation statistics (optional)
+
+    Logging can be configured via constructor parameters to enable/disable specific metrics.
+
+    Args:
+        collector (BaseCollector): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            the collector's weight_sync_schemes) to trainer source paths. Required if the collector has
+            weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}.
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with the prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default: False.
+
+    Examples:
+        >>> # Basic usage with manual configuration
+        >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
+        >>> from torchrl.trainers.algorithms.configs import PPOTrainerConfig
+        >>> from hydra.utils import instantiate
+        >>> config = PPOTrainerConfig(...)  # Configure with required parameters
+        >>> trainer = instantiate(config)
+        >>> trainer.train()
+
+    .. note::
+        This trainer requires a configurable environment setup. See the
+        :class:`~torchrl.trainers.algorithms.configs` module for configuration options.
+
+    .. warning::
+        This is an experimental feature. The API may change in future versions.
+        We welcome feedback and contributions to help improve this implementation!
+    """
+
+    def __init__(
+        self,
+        *,
+        collector: BaseCollector,
+        total_frames: int,
+        frame_skip: int,
+        optim_steps_per_batch: int,
+        loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
+        optimizer: optim.Optimizer | None = None,
+        logger: Logger | None = None,
+        clip_grad_norm: bool = True,
+        clip_norm: float | None = None,
+        progress_bar: bool = True,
+        seed: int | None = None,
+        save_trainer_interval: int = 10000,
+        log_interval: int = 10000,
+        save_trainer_file: str | pathlib.Path | None = None,
+        num_epochs: int = 4,
+        replay_buffer: ReplayBuffer | None = None,
+        batch_size: int | None = None,
+        gamma: float = 0.9,
+        lmbda: float = 0.99,
+        enable_logging: bool = True,
+        log_rewards: bool = True,
+        log_actions: bool = True,
+        log_observations: bool = False,
+        async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
+    ) -> None:
+        warnings.warn(
+            "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
+            "Please report any issues or feedback to help improve this implementation.",
+            UserWarning,
+            stacklevel=2,
+        )
+        super().__init__(
+            collector=collector,
+            total_frames=total_frames,
+            frame_skip=frame_skip,
+            optim_steps_per_batch=optim_steps_per_batch,
+            loss_module=loss_module,
+            optimizer=optimizer,
+            logger=logger,
+            clip_grad_norm=clip_grad_norm,
+            clip_norm=clip_norm,
+            progress_bar=progress_bar,
+            seed=seed,
+            save_trainer_interval=save_trainer_interval,
+            log_interval=log_interval,
+            save_trainer_file=save_trainer_file,
+            num_epochs=num_epochs,
+            async_collection=async_collection,
+            log_timings=log_timings,
+        )
+        self.replay_buffer = replay_buffer
+        self.async_collection = async_collection
+
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
+
+        if (
+            not self.async_collection
+            and replay_buffer is not None
+            and not isinstance(replay_buffer.sampler, SamplerWithoutReplacement)
+        ):
+            warnings.warn(
+                "Sampler is not a SamplerWithoutReplacement, which is required for PPO."
+            )
+
+        if replay_buffer is not None:
+            rb_trainer = ReplayBufferTrainer(
+                replay_buffer,
+                batch_size=None,
+                flatten_tensordicts=True,
+                memmap=False,
+                device=getattr(replay_buffer.storage, "device", "cpu"),
+                iterate=True,
+            )
+
+            if not self.async_collection:
+                # rb has been extended by the collector
+                self.register_op("pre_epoch", rb_trainer.extend)
+            self.register_op("process_optim_batch", rb_trainer.sample)
+            self.register_op("post_loss", rb_trainer.update_priority)
+
+        # Set up weight updates
+        # Validate weight_update_map if collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f"  Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f"  Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
+        self.register_op("post_steps", update_weights)
+
+        # Store logging configuration
+        self.enable_logging = enable_logging
+        self.log_rewards = log_rewards
+        self.log_actions = log_actions
+        self.log_observations = log_observations
+
+        # Set up comprehensive logging for PPO training
+        if self.enable_logging:
+            self._setup_ppo_logging()
+
+    def _setup_ppo_logging(self):
+        """Set up logging hooks for PPO-specific metrics.
+
+        This method configures logging for common PPO metrics including:
+        - Training rewards (mean and std)
+        - Action statistics (norms, entropy)
+        - Episode completion rates
+        - Value function statistics
+        - Advantage statistics
+        """
+        # Always log done states as percentage (episode completion rate)
+        log_done_percentage = LogScalar(
+            key=("next", "done"),
+            logname="done_percentage",
+            log_pbar=True,
+            include_std=False,  # No std for binary values
+            reduction="mean",
+        )
+        if not self.async_collection:
+            self.register_op("pre_steps_log", log_done_percentage)
+        else:
+            self.register_op("post_optim_log", log_done_percentage)
+
+        # Log rewards if enabled
+        if self.log_rewards:
+            # 1. Log training rewards (most important metric for PPO)
+            log_rewards = LogScalar(
+                key=("next", "reward"),
+                logname="r_training",
+                log_pbar=True,  # Show in progress bar
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_rewards)
+            else:
+                self.register_op("post_optim_log", log_rewards)
+
+            # 2. Log maximum reward in batch (for monitoring best performance)
+            log_max_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_max",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_max_reward)
+            else:
+                self.register_op("post_optim_log", log_max_reward)
+
+            # 3. Log total reward in batch (for monitoring cumulative performance)
+            log_total_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_total",
+                log_pbar=False,
+                include_std=False,
+                reduction="sum",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_total_reward)
+            else:
+                self.register_op("post_optim_log", log_total_reward)
+
+        # Log actions if enabled
+        if self.log_actions:
+            # 4. Log action norms (useful for monitoring policy behavior)
+            log_action_norm = LogScalar(
+                key="action",
+                logname="action_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_action_norm)
+            else:
+                self.register_op("post_optim_log", log_action_norm)
+
+        # Log observations if enabled
+        if self.log_observations:
+            # 5. Log observation statistics (for monitoring state distributions)
+            log_obs_norm = LogScalar(
+                key="observation",
+                logname="obs_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_obs_norm)
+            else:
+                self.register_op("post_optim_log", log_obs_norm)
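For context, the following is a minimal sketch (not taken from the package) of how the PPOTrainer above could be wired together by hand. The environment (`Pendulum-v1`), network sizes, and hyperparameters such as `frames_per_batch` are illustrative assumptions, not package defaults.

```python
# Hypothetical end-to-end setup for PPOTrainer; values below are illustrative only.
import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.trainers.algorithms.ppo import PPOTrainer

env = GymEnv("Pendulum-v1")
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

# Policy: observation -> (loc, scale) -> TanhNormal distribution over actions.
actor_net = torch.nn.Sequential(
    MLP(in_features=n_obs, out_features=2 * n_act, num_cells=[64, 64]),
    NormalParamExtractor(),
)
policy = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    spec=env.action_spec,
)
# Value network, used both by the PPO loss and by the trainer's default GAE module.
critic = ValueOperator(
    MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
)

loss_module = ClipPPOLoss(actor_network=policy, critic_network=critic)
collector = SyncDataCollector(env, policy, frames_per_batch=1000, total_frames=100_000)
# PPOTrainer warns unless the buffer samples without replacement over each batch.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=250,
)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

trainer = PPOTrainer(
    collector=collector,
    total_frames=100_000,
    frame_skip=1,
    optim_steps_per_batch=4,
    loss_module=loss_module,
    optimizer=optimizer,
    replay_buffer=replay_buffer,
    num_epochs=4,
)
trainer.train()
```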
torchrl/trainers/algorithms/sac.py
@@ -0,0 +1,308 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import pathlib
+import warnings
+
+from collections.abc import Callable
+
+from functools import partial
+
+from tensordict import TensorDict, TensorDictBase
+from torch import optim
+
+from torchrl.collectors import BaseCollector
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import TargetNetUpdater
+from torchrl.record.loggers import Logger
+from torchrl.trainers.trainers import (
+    LogScalar,
+    ReplayBufferTrainer,
+    TargetNetUpdaterHook,
+    Trainer,
+    UpdateWeights,
+    UTDRHook,
+)
+
+
+class SACTrainer(Trainer):
+    """A trainer class for the Soft Actor-Critic (SAC) algorithm.
+
+    This trainer implements the SAC algorithm, an off-policy actor-critic method that
+    optimizes a stochastic policy in an off-policy way, forming a bridge between
+    stochastic policy optimization and DDPG-style approaches. SAC incorporates the
+    entropy measure of the policy into the reward to encourage exploration.
+
+    The trainer handles:
+    - Replay buffer management for off-policy learning
+    - Target network updates with configurable update frequency
+    - Policy weight updates to the data collector
+    - Comprehensive logging of training metrics
+    - Gradient clipping and optimization steps
+
+    Args:
+        collector (BaseCollector): The data collector used to gather environment interactions.
+        total_frames (int): Total number of frames to collect during training.
+        frame_skip (int): Number of frames to skip between policy updates.
+        optim_steps_per_batch (int): Number of optimization steps per collected batch.
+        loss_module (LossModule | Callable): The SAC loss module or a callable that computes losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training. If None, must be configured elsewhere.
+        logger (Logger, optional): Logger for recording training metrics. Defaults to None.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Defaults to True.
+        clip_norm (float, optional): Maximum gradient norm for clipping. Defaults to None.
+        progress_bar (bool, optional): Whether to show a progress bar during training. Defaults to True.
+        seed (int, optional): Random seed for reproducibility. Defaults to None.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Defaults to 10000.
+        log_interval (int, optional): Interval for logging metrics. Defaults to 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state. Defaults to None.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing and sampling experiences. Defaults to None.
+        batch_size (int, optional): Batch size for sampling from the replay buffer. Defaults to None.
+        enable_logging (bool, optional): Whether to enable metric logging. Defaults to True.
+        log_rewards (bool, optional): Whether to log reward statistics. Defaults to True.
+        log_actions (bool, optional): Whether to log action statistics. Defaults to True.
+        log_observations (bool, optional): Whether to log observation statistics. Defaults to False.
+        target_net_updater (TargetNetUpdater, optional): Target network updater for soft updates. Defaults to None.
+
+    Example:
+        >>> from torchrl.collectors import Collector
+        >>> from torchrl.objectives import SACLoss
+        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
+        >>> from torch import optim
+        >>>
+        >>> # Set up collector, loss, and replay buffer
+        >>> collector = Collector(env, policy, frames_per_batch=1000)
+        >>> loss_module = SACLoss(actor_network, qvalue_network)
+        >>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4)
+        >>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
+        >>>
+        >>> # Create and run trainer
+        >>> trainer = SACTrainer(
+        ...     collector=collector,
+        ...     total_frames=1000000,
+        ...     frame_skip=1,
+        ...     optim_steps_per_batch=100,
+        ...     loss_module=loss_module,
+        ...     optimizer=optimizer,
+        ...     replay_buffer=replay_buffer,
+        ... )
+        >>> trainer.train()
+
+    Note:
+        This is an experimental/prototype feature. The API may change in future versions.
+        SAC is particularly effective for continuous control tasks and environments where
+        exploration is crucial due to its entropy regularization.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        collector: BaseCollector,
+        total_frames: int,
+        frame_skip: int,
+        optim_steps_per_batch: int,
+        loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
+        optimizer: optim.Optimizer | None = None,
+        logger: Logger | None = None,
+        clip_grad_norm: bool = True,
+        clip_norm: float | None = None,
+        progress_bar: bool = True,
+        seed: int | None = None,
+        save_trainer_interval: int = 10000,
+        log_interval: int = 10000,
+        save_trainer_file: str | pathlib.Path | None = None,
+        replay_buffer: ReplayBuffer | None = None,
+        batch_size: int | None = None,
+        enable_logging: bool = True,
+        log_rewards: bool = True,
+        log_actions: bool = True,
+        log_observations: bool = False,
+        target_net_updater: TargetNetUpdater | None = None,
+        async_collection: bool = False,
+        log_timings: bool = False,
+    ) -> None:
+        warnings.warn(
+            "SACTrainer is an experimental/prototype feature. The API may change in future versions. "
+            "Please report any issues or feedback to help improve this implementation.",
+            UserWarning,
+            stacklevel=2,
+        )
+        # try to get the action spec
+        self._pass_action_spec_from_collector_to_loss(collector, loss_module)
+
+        super().__init__(
+            collector=collector,
+            total_frames=total_frames,
+            frame_skip=frame_skip,
+            optim_steps_per_batch=optim_steps_per_batch,
+            loss_module=loss_module,
+            optimizer=optimizer,
+            logger=logger,
+            clip_grad_norm=clip_grad_norm,
+            clip_norm=clip_norm,
+            progress_bar=progress_bar,
+            seed=seed,
+            save_trainer_interval=save_trainer_interval,
+            log_interval=log_interval,
+            save_trainer_file=save_trainer_file,
+            async_collection=async_collection,
+            log_timings=log_timings,
+        )
+        self.replay_buffer = replay_buffer
+        self.async_collection = async_collection
+
+        # Note: SAC can use any sampler type, unlike PPO which requires SamplerWithoutReplacement
+
+        if replay_buffer is not None:
+            rb_trainer = ReplayBufferTrainer(
+                replay_buffer,
+                batch_size=None,
+                flatten_tensordicts=True,
+                memmap=False,
+                device=getattr(replay_buffer.storage, "device", "cpu"),
+                iterate=True,
+            )
+            if not self.async_collection:
+                self.register_op("pre_epoch", rb_trainer.extend)
+            self.register_op("process_optim_batch", rb_trainer.sample)
+            self.register_op("post_loss", rb_trainer.update_priority)
+            self.register_op("post_optim", TargetNetUpdaterHook(target_net_updater))
+
+        policy_weights_getter = partial(
+            TensorDict.from_module, self.loss_module.actor_network
+        )
+        update_weights = UpdateWeights(
+            self.collector, 1, policy_weights_getter=policy_weights_getter
+        )
+        self.register_op("post_steps", update_weights)
+
+        # Store logging configuration
+        self.enable_logging = enable_logging
+        self.log_rewards = log_rewards
+        self.log_actions = log_actions
+        self.log_observations = log_observations
+
+        # Set up comprehensive logging for SAC training
+        if self.enable_logging:
+            self._setup_sac_logging()
+
+    def _pass_action_spec_from_collector_to_loss(
+        self, collector: BaseCollector, loss: LossModule
+    ):
+        """Pass the action specification from the collector's environment to the loss module.
+
+        This method extracts the action specification from the collector's environment
+        and assigns it to the loss module if the loss module doesn't already have one.
+        This is necessary for SAC loss computation which requires knowledge of the
+        action space bounds for proper entropy calculation and action clipping.
+
+        Args:
+            collector (BaseCollector): The data collector containing the environment.
+            loss (LossModule): The loss module that needs the action specification.
+        """
+        if hasattr(loss, "_action_spec") and loss._action_spec is None:
+            action_spec = collector.getattr_env("full_action_spec_unbatched").cpu()
+            loss._action_spec = action_spec
+
+    def _setup_sac_logging(self):
+        """Set up logging hooks for SAC-specific metrics.
+
+        This method configures logging for common SAC metrics including:
+        - Training rewards (mean, max, total, and std)
+        - Action statistics (action norms)
+        - Episode completion rates (done percentage)
+        - Observation statistics (when enabled)
+        - Q-value and policy loss metrics (handled by loss module)
+        """
+        # Always log done states as percentage (episode completion rate)
+        log_done_percentage = LogScalar(
+            key=("next", "done"),
+            logname="done_percentage",
+            log_pbar=True,
+            include_std=False,  # No std for binary values
+            reduction="mean",
+        )
+        if not self.async_collection:
+            self.register_op("pre_steps_log", log_done_percentage)
+        else:
+            self.register_op("post_optim_log", log_done_percentage)
+
+        # Log rewards if enabled
+        if self.log_rewards:
+            # 1. Log training rewards (most important metric for SAC)
+            log_rewards = LogScalar(
+                key=("next", "reward"),
+                logname="r_training",
+                log_pbar=True,  # Show in progress bar
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_rewards)
+            else:
+                # In the async case, use the batch passed to the optimizer
+                self.register_op("post_optim_log", log_rewards)
+
+            # 2. Log maximum reward in batch (for monitoring best performance)
+            log_max_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_max",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_max_reward)
+            else:
+                self.register_op("post_optim_log", log_max_reward)
+
+            # 3. Log total reward in batch (for monitoring cumulative performance)
+            log_total_reward = LogScalar(
+                key=("next", "reward_sum"),
+                logname="r_total",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_total_reward)
+            else:
+                self.register_op("post_optim_log", log_total_reward)
+
+        # Log actions if enabled
+        if self.log_actions:
+            # 4. Log action norms (useful for monitoring policy behavior)
+            log_action_norm = LogScalar(
+                key="action",
+                logname="action_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_action_norm)
+            else:
+                self.register_op("post_optim_log", log_action_norm)
+
+        # Log observations if enabled
+        if self.log_observations:
+            # 5. Log observation statistics (for monitoring state distributions)
+            log_obs_norm = LogScalar(
+                key="observation",
+                logname="obs_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_obs_norm)
+            else:
+                self.register_op("post_optim_log", log_obs_norm)
+
+        self.register_op("pre_steps_log", UTDRHook(self))
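The `_setup_sac_logging` method above shows the general pattern for attaching metrics: build a `LogScalar` hook and register it with `register_op`. As an illustration (not part of the packaged source), the same pattern can be used to log additional quantities on an already-constructed trainer. The `("next", "step_count")` key assumes the environment is wrapped in a `StepCounter` transform, and `trainer` refers to a `SACTrainer` built as in the docstring example.

```python
# Hypothetical extra logging hook; assumes `trainer` exists and the env uses StepCounter.
from torchrl.trainers.trainers import LogScalar

episode_length_hook = LogScalar(
    key=("next", "step_count"),  # written by torchrl.envs.StepCounter
    logname="episode_length",
    log_pbar=False,
    include_std=True,
    reduction="mean",
)
# Synchronous collection: log from the freshly collected batch, as the built-in hooks do.
trainer.register_op("pre_steps_log", episode_length_hook)
```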