torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
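The listing above can be cross-checked directly against the wheel itself, since a .whl file is an ordinary zip archive. A minimal sketch, assuming the wheel has been downloaded locally (for instance with pip download torchrl==0.11.0 --no-deps); the path below is a hypothetical local filename:

# List the files packed into the wheel together with their sizes.
import zipfile

wheel_path = "torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl"  # assumed local path
with zipfile.ZipFile(wheel_path) as wheel:
    for info in sorted(wheel.infolist(), key=lambda i: i.filename):
        print(f"{info.filename}  ({info.file_size} bytes)")

The selected file diffs follow.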
sota-implementations/ddpg/ddpg.py
@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""DDPG Example.
+
+This is a simple self-contained example of a DDPG training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_ddpg_agent,
+    make_environment,
+    make_loss_module,
+    make_optimizer,
+    make_replay_buffer,
+)
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+    collector_device = (
+        torch.device(cfg.collector.device)
+        if cfg.collector.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="ddpg_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg, logger=logger)
+
+    # Create agent
+    model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device)
+
+    # Create DDPG loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg,
+        train_env,
+        exploration_policy,
+        compile=cfg.compile.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.compile.cudagraphs,
+        device=collector_device,
+    )
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device="cpu",
+    )
+
+    # Create optimizers
+    optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic)
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        td_loss: TensorDict = loss_module(sampled_tensordict)
+        td_loss.sum(reduce=True).backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return td_loss.detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_iter = cfg.logger.eval_iter
+    eval_rollout_steps = cfg.env.max_episode_steps
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+        with timeit("collecting"):
+            tensordict = next(c_iter)
+        # Update exploration policy
+        exploration_policy[1].step(tensordict.numel())
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
+        # Add to replay buffer
+        with timeit("rb - extend"):
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
+        collected_frames += current_frames
+
+        # Optimization steps
+        if collected_frames >= init_random_frames:
+            tds = []
+            for _ in range(num_updates):
+                # Sample from replay buffer
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample().to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    td_loss = update(sampled_tensordict)
+                tds.append(td_loss.clone())
+
+                # Update priority
+                if prb:
+                    replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds)
+
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+
+        if collected_frames >= init_random_frames:
+            tds = TensorDict(train=tds).flatten_keys("/").mean()
+            metrics_to_log.update(tds.to_dict())
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    exploration_policy,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+    if not train_env.is_closed:
+        train_env.close()
+
+
+if __name__ == "__main__":
+    main()
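The update() closure in the script above backpropagates a single summed loss through both networks and relies on group_optimizers to step the actor and critic optimizers together with one call. A minimal sketch of that pattern with throwaway modules; the networks and the surrogate loss below are illustrative only, not the script's:

# Illustrative sketch: two toy networks stand in for the actor/critic, and a
# single grouped optimizer steps both parameter sets after one backward pass.
import torch
from torch import nn, optim
from torchrl.objectives import group_optimizers

actor_net = nn.Linear(4, 2)
critic_net = nn.Linear(6, 1)
optimizer_actor = optim.Adam(actor_net.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(critic_net.parameters(), lr=3e-4)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

obs = torch.randn(8, 4)
action = actor_net(obs)
q_value = critic_net(torch.cat([obs, action], dim=-1))
loss = (-q_value).mean()  # toy surrogate for the summed actor/value losses

optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()  # updates both parameter sets in one call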
sota-implementations/ddpg/utils.py
@@ -0,0 +1,325 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
+
+import torch
+
+from tensordict.nn import TensorDictModule, TensorDictSequential
+
+from torch import nn, optim
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.envs import (
+    CatTensors,
+    Compose,
+    DMControlEnv,
+    DoubleToFloat,
+    EnvCreator,
+    InitTracker,
+    ParallelEnv,
+    RewardSum,
+    StepCounter,
+    TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import (
+    AdditiveGaussianModule,
+    MLP,
+    OrnsteinUhlenbeckProcessModule,
+    TanhModule,
+    ValueOperator,
+)
+
+from torchrl.objectives import SoftUpdate
+from torchrl.objectives.ddpg import DDPGLoss
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu", from_pixels=False):
+    lib = cfg.env.library
+    if lib in ("gym", "gymnasium"):
+        with set_gym_backend(lib):
+            return GymEnv(
+                cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+            )
+    elif lib == "dm_control":
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
+        return TransformedEnv(
+            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+        )
+    else:
+        raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(env, max_episode_steps=1000):
+    transformed_env = TransformedEnv(
+        env,
+        Compose(
+            InitTracker(),
+            StepCounter(max_episode_steps),
+            DoubleToFloat(),
+            RewardSum(),
+        ),
+    )
+    return transformed_env
+
+
+def make_environment(cfg, logger):
+    """Make environments for training and evaluation."""
+    maker = functools.partial(env_maker, cfg, from_pixels=False)
+    parallel_env = ParallelEnv(
+        cfg.collector.env_per_collector,
+        EnvCreator(maker),
+        serial_for_single=True,
+    )
+    parallel_env.set_seed(cfg.env.seed)
+
+    train_env = apply_env_transforms(
+        parallel_env, max_episode_steps=cfg.env.max_episode_steps
+    )
+
+    maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+    eval_env = TransformedEnv(
+        ParallelEnv(
+            cfg.logger.num_eval_envs,
+            EnvCreator(maker),
+            serial_for_single=True,
+        ),
+        train_env.transform.clone(),
+    )
+    eval_env.set_seed(0)
+    if cfg.logger.video:
+        eval_env = eval_env.append_transform(
+            VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+    return train_env, eval_env
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraph=False,
+    device: torch.device | None = None,
+):
+    """Make collector."""
+    collector = SyncDataCollector(
+        train_env,
+        actor_model_explore,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        init_random_frames=cfg.collector.init_random_frames,
+        reset_at_each_iter=cfg.collector.reset_at_each_iter,
+        total_frames=cfg.collector.total_frames,
+        device=device,
+        compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False,
+        cudagraph_policy=cudagraph,
+    )
+    collector.set_seed(cfg.env.seed)
+    return collector
+
+
+def make_replay_buffer(
+    batch_size,
+    prb=False,
+    buffer_size=1000000,
+    scratch_dir=None,
+    device="cpu",
+    prefetch=3,
+):
+    if prb:
+        replay_buffer = TensorDictPrioritizedReplayBuffer(
+            alpha=0.7,
+            beta=0.5,
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    else:
+        replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    return replay_buffer
+
+
+# ====================================================================
+# Model
+# -----
+
+
+def make_ddpg_agent(cfg, train_env, eval_env, device):
+    """Make DDPG agent."""
+    # Define Actor Network
+    in_keys = ["observation"]
+    action_spec = train_env.action_spec_unbatched
+    actor_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": action_spec.shape[-1],
+        "activation_class": get_activation(cfg),
+    }
+
+    actor_net = MLP(**actor_net_kwargs)
+
+    in_keys_actor = in_keys
+    actor_module = TensorDictModule(
+        actor_net,
+        in_keys=in_keys_actor,
+        out_keys=["param"],
+    )
+    actor = TensorDictSequential(
+        actor_module,
+        TanhModule(
+            in_keys=["param"],
+            out_keys=["action"],
+        ),
+    )
+
+    # Define Critic Network
+    qvalue_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": 1,
+        "activation_class": get_activation(cfg),
+    }
+
+    qvalue_net = MLP(
+        **qvalue_net_kwargs,
+    )
+
+    qvalue = ValueOperator(
+        in_keys=["action"] + in_keys,
+        module=qvalue_net,
+    )
+
+    model = nn.ModuleList([actor, qvalue]).to(device)
+
+    # init nets
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+        td = eval_env.reset()
+        td = td.to(device)
+        for net in model:
+            net(td)
+    del td
+    eval_env.close()
+
+    # Exploration wrappers:
+    if cfg.network.noise_type == "ou":
+        actor_model_explore = TensorDictSequential(
+            model[0],
+            OrnsteinUhlenbeckProcessModule(
+                spec=action_spec,
+                annealing_num_steps=1_000_000,
+                device=device,
+                safe=False,
+            ),
+        )
+    elif cfg.network.noise_type == "gaussian":
+        actor_model_explore = TensorDictSequential(
+            model[0],
+            AdditiveGaussianModule(
+                spec=action_spec,
+                sigma_end=1.0,
+                sigma_init=1.0,
+                mean=0.0,
+                std=0.1,
+                device=device,
+                safe=False,
+            ),
+        )
+    else:
+        raise NotImplementedError
+
+    return model, actor_model_explore
+
+
+# ====================================================================
+# DDPG Loss
+# ---------
+
+
+def make_loss_module(cfg, model):
+    """Make loss module and target network updater."""
+    # Create DDPG loss
+    loss_module = DDPGLoss(
+        actor_network=model[0],
+        value_network=model[1],
+        loss_function=cfg.optim.loss_function,
+        delay_actor=True,
+        delay_value=True,
+    )
+    loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+    # Define Target Network Updater
+    target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
+    return loss_module, target_net_updater
+
+
+def make_optimizer(cfg, loss_module):
+    critic_params = list(loss_module.value_network_params.flatten_keys().values())
+    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+    optimizer_actor = optim.Adam(
+        actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay
+    )
+    optimizer_critic = optim.Adam(
+        critic_params,
+        lr=cfg.optim.lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    return optimizer_actor, optimizer_critic
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    for metric_name, metric_value in metrics.items():
+        logger.log_scalar(metric_name, metric_value, step)
+
+
+def get_activation(cfg):
+    if cfg.network.activation == "relu":
+        return nn.ReLU
+    elif cfg.network.activation == "tanh":
+        return nn.Tanh
+    elif cfg.network.activation == "leaky_relu":
+        return nn.LeakyReLU
+    else:
+        raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
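As a usage sketch of the replay-buffer helper defined above, assuming utils.py from this directory is importable; the sizes and keys below are arbitrary placeholders:

# Build a small uniform buffer, extend it with a dummy batch, and sample from it.
import torch
from tensordict import TensorDict
from utils import make_replay_buffer

rb = make_replay_buffer(batch_size=32, prb=False, buffer_size=1_000, device="cpu")
dummy_batch = TensorDict(
    {"observation": torch.randn(64, 17), "action": torch.randn(64, 6)},
    batch_size=[64],
)
rb.extend(dummy_batch)
sample = rb.sample()  # TensorDict with batch_size [32]
print(sample["observation"].shape)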
sota-implementations/decision_transformer/dt.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Decision Transformer Example.
+This is a self-contained example of an offline Decision Transformer training script.
+The helper functions are coded in the utils.py associated with this script.
+"""
+
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
+from torchrl.envs.libs.gym import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
+from torchrl.record import VideoRecorder
+from utils import (
+    dump_video,
+    log_metrics,
+    make_dt_loss,
+    make_dt_model,
+    make_dt_optimizer,
+    make_env,
+    make_logger,
+    make_offline_replay_buffer,
+)
+
+
+@hydra.main(config_path="", config_name="dt_config", version_base="1.1")
+def main(cfg: DictConfig):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
+    model_device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
+    logger = make_logger(cfg)
+
+    # Create offline replay buffer
+    offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
+        cfg.replay_buffer, cfg.env.reward_scaling
+    )
+
+    # Create test environment
+    test_env = make_env(
+        cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device
+    )
+    if cfg.logger.video:
+        test_env = test_env.append_transform(
+            VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+
+    # Create policy model
+    actor = make_dt_model(cfg, device=model_device)
+
+    # Create loss
+    loss_module = make_dt_loss(cfg.loss, actor, device=model_device)
+
+    # Create optimizer
+    transformer_optim, scheduler = make_dt_optimizer(
+        cfg.optim, loss_module, model_device
+    )
+
+    # Create inference policy
+    inference_policy = DecisionTransformerInferenceWrapper(
+        policy=actor,
+        inference_context=cfg.env.inference_context,
+        device=model_device,
+    )
+    inference_policy.set_tensor_keys(
+        observation="observation_cat",
+        action="action_cat",
+        return_to_go="return_to_go_cat",
+    )
+
+    pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
+    clip_grad = cfg.optim.clip_grad
+
+    def update(data: TensorDict) -> TensorDict:
+        transformer_optim.zero_grad(set_to_none=True)
+        # Compute loss
+        loss_vals = loss_module(data)
+        transformer_loss = loss_vals["loss"]
+
+        transformer_loss.backward()
+        torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
+        transformer_optim.step()
+
+        return loss_vals
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode, dynamic=True)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    eval_steps = cfg.logger.eval_steps
+    pretrain_log_interval = cfg.logger.pretrain_log_interval
+    reward_scaling = cfg.env.reward_scaling
+
+    torchrl_logger.info(" ***Pretraining*** ")
+    # Pretraining
+    pbar = tqdm.tqdm(range(pretrain_gradient_steps))
+    for i in pbar:
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
+        # Sample data
+        with timeit("rb - sample"):
+            data = offline_buffer.sample().to(model_device)
+        with timeit("update"):
+            loss_vals = update(data)
+        scheduler.step()
+        # Log metrics
+        metrics_to_log = {"train/loss": loss_vals["loss"]}
+
+        # Evaluation
+        with set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), torch.no_grad(), timeit("eval"):
+            if i % pretrain_log_interval == 0:
+                eval_td = test_env.rollout(
+                    max_steps=eval_steps,
+                    policy=inference_policy,
+                    auto_cast_to_device=True,
+                )
+                test_env.apply(dump_video)
+                metrics_to_log["eval/reward"] = (
+                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+                )
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
+
+    pbar.close()
+    if not test_env.is_closed:
+        test_env.close()
+
+
+if __name__ == "__main__":
+    main()
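Both scripts wrap their update step the same way: torch.compile first, then an optional CudaGraphModule. A stripped-down sketch of that wrapping order on a dummy update function; the two flags below play the role of cfg.compile.compile and cfg.compile.cudagraphs, and the cudagraph path requires a CUDA device:

# Dummy stand-in for the real update step; only the wrapping order mirrors the scripts above.
import torch
from tensordict.nn import CudaGraphModule

use_compile = True
use_cudagraphs = torch.cuda.is_available()  # CudaGraphModule needs a CUDA device


def update(loss: torch.Tensor) -> torch.Tensor:
    return loss * 0.5  # pretend this is one optimization step


if use_compile:
    # Use "default" when cudagraphs are handled by CudaGraphModule, otherwise let
    # torch.compile apply its own overhead-reduction mode.
    compile_mode = "default" if use_cudagraphs else "reduce-overhead"
    update = torch.compile(update, mode=compile_mode, dynamic=True)
if use_cudagraphs:
    update = CudaGraphModule(update, warmup=50)

device = "cuda" if use_cudagraphs else "cpu"
out = update(torch.ones((), device=device))
print(out.item())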