torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/cql/discrete_cql_online.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Discrete (DQN) CQL Example.
+
+This is a simple self-contained example of a discrete CQL training script.
+
+It supports state environments like gym and gymnasium.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    log_metrics,
+    make_collector,
+    make_discrete_cql_optimizer,
+    make_discrete_loss,
+    make_discretecql_model,
+    make_environment,
+    make_replay_buffer,
+)
+
+torch.set_float32_matmul_precision("high")
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="discrete_online_config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="discretecql_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg)
+
+    # Create agent
+    model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)
+
+    # Create loss
+    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg,
+        train_env,
+        explore_policy,
+        compile=cfg.compile.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.compile.cudagraphs,
+    )
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device="cpu",
+    )
+
+    # Create optimizers
+    optimizer = make_discrete_cql_optimizer(cfg, loss_module)
+
+    def update(sampled_tensordict):
+        # Compute loss
+        optimizer.zero_grad(set_to_none=True)
+        loss_dict = loss_module(sampled_tensordict)
+
+        q_loss = loss_dict["loss_qvalue"]
+        cql_loss = loss_dict["loss_cql"]
+        loss = q_loss + cql_loss
+
+        # Update model
+        loss.backward()
+        optimizer.step()
+
+        # Update target params
+        target_net_updater.step()
+        return loss_dict.detach()
+
+    if compile_mode:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    eval_rollout_steps = cfg.env.max_episode_steps
+    eval_iter = cfg.logger.eval_iter
+    frames_per_batch = cfg.collector.frames_per_batch
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+        with timeit("collecting"):
+            torch.compiler.cudagraph_mark_step_begin()
+            tensordict = next(c_iter)

+        # Update exploration policy
+        explore_policy[1].step(tensordict.numel())
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
+        tensordict = tensordict.reshape(-1)
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            replay_buffer.extend(tensordict)
+        collected_frames += current_frames
+
+        # Optimization steps
+        if collected_frames >= init_random_frames:
+            tds = []
+            for _ in range(num_updates):
+                # Sample from replay buffer
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample()
+                    sampled_tensordict = sampled_tensordict.to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss_dict = update(sampled_tensordict).clone()
+                tds.append(loss_dict)
+
+                # Update priority
+                if prb:
+                    replay_buffer.update_priority(sampled_tensordict)
+
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        metrics_to_log = {}
+        # Evaluation
+        with timeit("eval"):
+            if collected_frames % eval_iter < frames_per_batch:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_rollout = eval_env.rollout(
+                        eval_rollout_steps,
+                        model,
+                        auto_cast_to_device=True,
+                        break_when_any_done=True,
+                    )
+                    eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                    metrics_to_log["eval/reward"] = eval_reward
+
+        # Logging
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+        metrics_to_log["train/epsilon"] = explore_policy[1].eps
+
+        if collected_frames >= init_random_frames:
+            tds = torch.stack(tds, dim=0).mean()
+            metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+            metrics_to_log["train/cql_loss"] = tds["loss_cql"]
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+
+
+if __name__ == "__main__":
+    main()
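Note (commentary, not part of the packaged files): the call explore_policy[1].step(tensordict.numel()) in the loop above drives the epsilon-greedy annealing of the EGreedyModule built in make_discretecql_model (shown in utils.py below). A minimal standalone sketch of that schedule, assuming a hypothetical 4-action spec and illustrative hyperparameters:

    from torchrl.data import OneHot
    from torchrl.modules import EGreedyModule

    spec = OneHot(4)  # hypothetical discrete action spec with 4 actions
    greedy = EGreedyModule(
        spec=spec, eps_init=1.0, eps_end=0.05, annealing_num_steps=1_000
    )
    for _ in range(10):
        greedy.step(100)  # one call per collected batch, as in the training loop
    print(float(greedy.eps))  # eps has annealed from eps_init toward eps_end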
sota-implementations/cql/utils.py
@@ -0,0 +1,471 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
+
+import torch.nn
+import torch.optim
+from tensordict.nn import TensorDictModule, TensorDictSequential
+from tensordict.nn.distributions import NormalParamExtractor
+
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import (
+    Composite,
+    LazyMemmapStorage,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
+)
+from torchrl.data.datasets.minari_data import MinariExperienceReplay
+from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.envs import (
+    CatTensors,
+    Compose,
+    DMControlEnv,
+    DoubleToFloat,
+    EnvCreator,
+    ParallelEnv,
+    RewardSum,
+    TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import (
+    EGreedyModule,
+    MLP,
+    ProbabilisticActor,
+    QValueActor,
+    TanhNormal,
+    ValueOperator,
+)
+from torchrl.objectives import CQLLoss, DiscreteCQLLoss, SoftUpdate
+from torchrl.record import VideoRecorder
+
+from torchrl.trainers.helpers.models import ACTIVATIONS
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu", from_pixels=False):
+    lib = cfg.env.backend
+    if lib in ("gym", "gymnasium"):
+        with set_gym_backend(lib):
+            return GymEnv(
+                cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+            )
+    elif lib == "dm_control":
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
+        return TransformedEnv(
+            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+        )
+    else:
+        raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(
+    env,
+):
+    transformed_env = TransformedEnv(
+        env,
+        Compose(
+            DoubleToFloat(),
+            RewardSum(),
+        ),
+    )
+    return transformed_env
+
+
+def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
+    """Make environments for training and evaluation."""
+    maker = functools.partial(env_maker, cfg)
+    parallel_env = ParallelEnv(
+        train_num_envs,
+        EnvCreator(maker),
+        serial_for_single=True,
+    )
+    parallel_env.set_seed(cfg.env.seed)
+
+    train_env = apply_env_transforms(parallel_env)
+
+    maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+    eval_env = TransformedEnv(
+        ParallelEnv(
+            eval_num_envs,
+            EnvCreator(maker),
+            serial_for_single=True,
+        ),
+        train_env.transform.clone(),
+    )
+    eval_env.set_seed(0)
+    if cfg.logger.video:
+        eval_env = eval_env.insert_transform(
+            0, VideoRecorder(logger=logger, tag="rendered", in_keys=["pixels"])
+        )
+    return train_env, eval_env
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraph=False,
+):
+    """Make collector."""
+    device = cfg.collector.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+    collector = SyncDataCollector(
+        train_env,
+        actor_model_explore,
+        init_random_frames=cfg.collector.init_random_frames,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        max_frames_per_traj=cfg.collector.max_frames_per_traj,
+        total_frames=cfg.collector.total_frames,
+        device=device,
+        compile_policy={"mode": compile_mode} if compile else False,
+        cudagraph_policy=cudagraph,
+    )
+    collector.set_seed(cfg.env.seed)
+    return collector
+
+
+def make_replay_buffer(
+    batch_size,
+    prb=False,
+    buffer_size=1000000,
+    scratch_dir=None,
+    device="cpu",
+    prefetch=3,
+):
+    if prb:
+        replay_buffer = TensorDictPrioritizedReplayBuffer(
+            alpha=0.7,
+            beta=0.5,
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    else:
+        replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    return replay_buffer
+
+
+def make_offline_replay_buffer(rb_cfg):
+    data = MinariExperienceReplay(
+        dataset_id=rb_cfg.dataset,
+        split_trajs=False,
+        batch_size=rb_cfg.batch_size,
+        sampler=SamplerWithoutReplacement(drop_last=True),
+        prefetch=4,
+        download=True,
+    )
+
+    data.append_transform(DoubleToFloat())
+
+    return data
+
+
+def make_offline_discrete_replay_buffer(rb_cfg):
+    import gymnasium as gym
+    import minari
+    from minari import DataCollector
+
+    # Create custom minari dataset from environment
+
+    env = gym.make(rb_cfg.env)
+    env = DataCollector(env)
+
+    for _ in range(rb_cfg.episodes):
+        env.reset(seed=123)
+        while True:
+            action = env.action_space.sample()
+            obs, rew, terminated, truncated, info = env.step(action)
+            if terminated or truncated:
+                break
+
+    env.create_dataset(
+        dataset_id=rb_cfg.dataset,
+        algorithm_name="Random-Policy",
+        code_permalink="https://github.com/Farama-Foundation/Minari",
+        author="Farama",
+        author_email="contact@farama.org",
+    )
+
+    data = MinariExperienceReplay(
+        dataset_id=rb_cfg.dataset,
+        split_trajs=False,
+        batch_size=rb_cfg.batch_size,
+        load_from_local_minari=True,
+        sampler=SamplerWithoutReplacement(drop_last=True),
+        prefetch=4,
+    )
+
+    data.append_transform(DoubleToFloat())
+
+    # Clean up
+    minari.delete_dataset(rb_cfg.dataset)
+
+    return data
+
+
+# ====================================================================
+# Model
+# -----
+#
+# We give one version of the model for learning from pixels, and one for state.
+# TorchRL comes in handy at this point, as the high-level interactions with
+# these models are unchanged, regardless of the modality.
+#
+
+
+def make_cql_model(cfg, train_env, eval_env, device="cpu"):
+    model_cfg = cfg.model
+
+    action_spec = train_env.action_spec_unbatched
+
+    actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
+    in_keys = ["observation"]
+    out_keys = ["loc", "scale"]
+
+    actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys)
+
+    # We use a ProbabilisticActor to make sure that we map the
+    # network output to the right space using a TanhNormal
+    # distribution.
+    actor = ProbabilisticActor(
+        module=actor_module,
+        in_keys=["loc", "scale"],
+        spec=action_spec,
+        distribution_class=TanhNormal,
+        # Wrapping the kwargs in a TensorDictParams such that these items are
+        # sent to device when necessary - not compatible with compile yet
+        # distribution_kwargs=TensorDictParams(
+        #     TensorDict(
+        #         {
+        #             "low": torch.as_tensor(action_spec.space.low, device=device),
+        #             "high": torch.as_tensor(action_spec.space.high, device=device),
+        #             "tanh_loc": NonTensorData(False),
+        #         }
+        #     ),
+        #     no_convert=True,
+        # ),
+        distribution_kwargs={
+            "low": action_spec.space.low.to(device),
+            "high": action_spec.space.high.to(device),
+            "tanh_loc": False,
+        },
+        default_interaction_type=ExplorationType.RANDOM,
+    )
+
+    in_keys = ["observation", "action"]
+
+    out_keys = ["state_action_value"]
+    qvalue = ValueOperator(
+        in_keys=in_keys,
+        out_keys=out_keys,
+        module=q_net,
+    )
+
+    model = torch.nn.ModuleList([actor, qvalue]).to(device)
+    # init nets
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+        td = eval_env.reset()
+        td = td.to(device)
+        for net in model:
+            net(td)
+    del td
+    eval_env.close()
+
+    return model
+
+
+def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):
+    model_cfg = cfg.model
+
+    action_spec = train_env.action_spec
+
+    actor_net_kwargs = {
+        "num_cells": model_cfg.hidden_sizes,
+        "out_features": action_spec.shape[-1],
+        "activation_class": ACTIVATIONS[model_cfg.activation],
+    }
+    actor_net = MLP(**actor_net_kwargs)
+    qvalue_module = QValueActor(
+        module=actor_net,
+        spec=Composite(action=action_spec),
+        in_keys=["observation"],
+    )
+    qvalue_module = qvalue_module.to(device)
+    # init nets
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+        td = eval_env.reset()
+        td = td.to(device)
+        qvalue_module(td)
+
+    del td
+    greedy_module = EGreedyModule(
+        annealing_num_steps=cfg.collector.annealing_frames,
+        eps_init=cfg.collector.eps_start,
+        eps_end=cfg.collector.eps_end,
+        spec=action_spec,
+    )
+    model_explore = TensorDictSequential(
+        qvalue_module,
+        greedy_module,
+    ).to(device)
+    return qvalue_module, model_explore
+
+
+def make_cql_modules_state(model_cfg, proof_environment):
+    action_spec = proof_environment.action_spec_unbatched
+
+    actor_net_kwargs = {
+        "num_cells": model_cfg.hidden_sizes,
+        "out_features": 2 * action_spec.shape[-1],
+        "activation_class": ACTIVATIONS[model_cfg.activation],
+    }
+    actor_net = MLP(**actor_net_kwargs)
+    actor_extractor = NormalParamExtractor(
+        scale_mapping=f"biased_softplus_{model_cfg.default_policy_scale}",
+        scale_lb=model_cfg.scale_lb,
+    )
+    actor_net = torch.nn.Sequential(actor_net, actor_extractor)
+
+    qvalue_net_kwargs = {
+        "num_cells": model_cfg.hidden_sizes,
+        "out_features": 1,
+        "activation_class": ACTIVATIONS[model_cfg.activation],
+    }
+
+    q_net = MLP(**qvalue_net_kwargs)
+
+    return actor_net, q_net
+
+
+# ====================================================================
+# CQL Loss
+# ---------
+
+
+def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
+    loss_module = CQLLoss(
+        model[0],
+        model[1],
+        loss_function=loss_cfg.loss_function,
+        temperature=loss_cfg.temperature,
+        min_q_weight=loss_cfg.min_q_weight,
+        max_q_backup=loss_cfg.max_q_backup,
+        deterministic_backup=loss_cfg.deterministic_backup,
+        num_random=loss_cfg.num_random,
+        with_lagrange=loss_cfg.with_lagrange,
+        lagrange_thresh=loss_cfg.lagrange_thresh,
+    )
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
+    target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
+
+    return loss_module, target_net_updater
+
+
+def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
+
+    if "action_space" in loss_cfg:  # specify action space
+        loss_module = DiscreteCQLLoss(
+            model,
+            loss_function=loss_cfg.loss_function,
+            action_space=loss_cfg.action_space,
+            delay_value=True,
+        )
+    else:
+        loss_module = DiscreteCQLLoss(
+            model,
+            loss_function=loss_cfg.loss_function,
+            delay_value=True,
+        )
+
+    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
+    target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
+
+    return loss_module, target_net_updater
+
+
+def make_discrete_cql_optimizer(cfg, loss_module):
+    optim = torch.optim.Adam(
+        loss_module.parameters(),
+        lr=cfg.optim.lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    return optim
+
+
+def make_continuous_cql_optimizer(cfg, loss_module):
+    critic_params = loss_module.qvalue_network_params.flatten_keys().values()
+    actor_params = loss_module.actor_network_params.flatten_keys().values()
+    actor_optim = torch.optim.Adam(
+        actor_params,
+        lr=cfg.optim.actor_lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    critic_optim = torch.optim.Adam(
+        critic_params,
+        lr=cfg.optim.critic_lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    alpha_optim = torch.optim.Adam(
+        [loss_module.log_alpha],
+        lr=cfg.optim.actor_lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    if loss_module.with_lagrange:
+        alpha_prime_optim = torch.optim.Adam(
+            [loss_module.log_alpha_prime],
+            lr=cfg.optim.critic_lr,
+        )
+    else:
+        alpha_prime_optim = None
+    return actor_optim, critic_optim, alpha_optim, alpha_prime_optim
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    if logger is not None:
+        for metric_name, metric_value in metrics.items():
+            logger.log_scalar(metric_name, metric_value, step)
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
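Note (commentary, not part of the packaged files): the helpers above can also be wired together without the Hydra entry point. A minimal sketch, assuming gymnasium with CartPole-v1 is installed, utils.py is importable, and all values are illustrative placeholders rather than the packaged discrete_online_config defaults:

    import torch
    from omegaconf import OmegaConf
    from utils import (
        make_discrete_cql_optimizer,
        make_discrete_loss,
        make_discretecql_model,
        make_environment,
    )

    cfg = OmegaConf.create(
        {
            "env": {"backend": "gymnasium", "name": "CartPole-v1", "seed": 0},
            "logger": {"video": False},
            "model": {"hidden_sizes": [256, 256], "activation": "relu"},
            "collector": {"annealing_frames": 10_000, "eps_start": 1.0, "eps_end": 0.05},
            "loss": {"loss_function": "l2", "gamma": 0.99, "tau": 0.005},
            "optim": {"lr": 3e-4, "weight_decay": 0.0},
        }
    )
    train_env, eval_env = make_environment(cfg)
    # make_discretecql_model returns the raw Q-value actor and its epsilon-greedy wrapper
    qvalue_module, explore_policy = make_discretecql_model(cfg, train_env, eval_env, "cpu")
    loss_module, target_updater = make_discrete_loss(cfg.loss, qvalue_module, device=torch.device("cpu"))
    optimizer = make_discrete_cql_optimizer(cfg, loss_module)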