torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/dreamer/dreamer.py
@@ -0,0 +1,586 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import contextlib
+import functools
+import importlib.util
+import time
+
+import hydra
+import torch
+import torch.cuda
+import tqdm
+
+from dreamer_utils import (
+    _default_device,
+    DreamerProfiler,
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_dreamer,
+    make_environments,
+    make_replay_buffer,
+    make_storage_transform,
+)
+from omegaconf import DictConfig
+
+# mixed precision training
+from torch.amp import GradScaler
+from torch.autograd.profiler import record_function
+from torch.nn.utils import clip_grad_norm_
+from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit
+from torchrl.envs.llm.transforms import PolicyVersion
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives.dreamer import (
+    DreamerActorLoss,
+    DreamerModelLoss,
+    DreamerValueLoss,
+)
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    # cfg = correct_for_frame_skip(cfg)
+
+    device = _default_device(cfg.networks.device)
+    assert device.type == "cuda", "Dreamer only supports CUDA devices"
+
+    # Early check for video dependencies before starting training
+    if cfg.logger.video:
+        missing_deps = []
+        if importlib.util.find_spec("moviepy") is None:
+            missing_deps.append("moviepy (pip install moviepy)")
+        if importlib.util.find_spec("torchvision") is None:
+            missing_deps.append("torchvision (pip install torchvision)")
+        if missing_deps:
+            raise ImportError(
+                f"Video logging requires: {', '.join(missing_deps)}\n"
+                "Alternatively, disable video logging with: logger.video=False"
+            )
+
+    # Create logger
+    exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="dreamer_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "project": cfg.logger.project,
+            },
+        )
+        # Log hyperparameters using wandb.config.update() with OmegaConf resolution
+        # This properly resolves interpolations like ${env.name} and uses the official wandb API
+        if hasattr(logger, "log_hparams"):
+            logger.log_hparams(cfg)
+
+    # make_environments returns (train_env_factory, test_env) for async collection
+    train_env_factory, test_env = make_environments(
+        cfg=cfg,
+        parallel_envs=cfg.env.n_parallel_envs,
+        logger=logger,
+    )
+
+    # Make dreamer components
+    action_key = "action"
+    value_key = "state_value"
+    (
+        world_model,
+        model_based_env,
+        model_based_env_eval,
+        actor_model,
+        value_model,
+        policy,
+    ) = make_dreamer(
+        cfg=cfg,
+        device=device,
+        action_key=action_key,
+        value_key=value_key,
+        use_decoder_in_env=cfg.logger.video,
+        logger=logger,
+    )
+    # Losses
+    world_model_loss = DreamerModelLoss(world_model)
+    # Adapt loss keys to gym backend
+    if cfg.env.backend == "gym":
+        world_model_loss.set_keys(pixels="observation", reco_pixels="reco_observation")
+
+    actor_loss = DreamerActorLoss(
+        actor_model,
+        value_model,
+        model_based_env,
+        imagination_horizon=cfg.optimization.imagination_horizon,
+        discount_loss=True,
+    )
+
+    actor_loss.make_value_estimator(
+        gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda
+    )
+    value_loss = DreamerValueLoss(
+        value_model, discount_loss=True, gamma=cfg.optimization.gamma
+    )
+
+    # Make replay buffer with minimal sample-time transforms
+    # Note: Buffer must be created BEFORE collector for true async collection
+    batch_size = cfg.replay_buffer.batch_size
+    batch_length = cfg.replay_buffer.batch_length
+    buffer_size = cfg.replay_buffer.buffer_size
+    scratch_dir = cfg.replay_buffer.scratch_dir
+    prefetch = cfg.replay_buffer.prefetch
+    profiling_enabled = cfg.profiling.enabled
+    replay_buffer = make_replay_buffer(
+        batch_size=batch_size,
+        batch_seq_len=batch_length,
+        buffer_size=buffer_size,
+        buffer_scratch_dir=scratch_dir,
+        device=device,
+        prefetch=prefetch if not profiling_enabled else None,
+        pixel_obs=cfg.env.from_pixels,
+        grayscale=cfg.env.grayscale,
+        image_size=cfg.env.image_size,
+    )
+
+    # Create storage transform for extend-time processing (applied once per frame)
+    storage_transform = make_storage_transform(
+        pixel_obs=cfg.env.from_pixels,
+        grayscale=cfg.env.grayscale,
+        image_size=cfg.env.image_size,
+    )
+
+    # Create policy version tracker for async collection
+    # This tracks policy versions so we can correlate collected data with policy updates
+    policy_version = PolicyVersion(version_type="int")
+
+    # Make async multi-collector with replay buffer for true async collection
+    # Device allocation: cuda:0 for training, cuda:1+ for collectors (if multi-GPU)
+    collector = make_collector(
+        cfg,
+        train_env_factory,
+        policy,
+        training_device=device,
+        replay_buffer=replay_buffer,
+        storage_transform=storage_transform,
+        track_policy_version=policy_version,
+    )
+
+    # Enable collector worker profiling if configured
+    if profiling_enabled and cfg.profiling.collector.enabled:
+        torchrl_logger.info(
+            f"Enabling collector profiling: workers={cfg.profiling.collector.workers}, "
+            f"num_rollouts={cfg.profiling.collector.num_rollouts}, "
+            f"warmup_rollouts={cfg.profiling.collector.warmup_rollouts}, "
+            f"init_random_frames_override={cfg.profiling.collector.init_random_frames_override}"
+        )
+        collector.enable_profile(
+            workers=list(cfg.profiling.collector.workers),
+            num_rollouts=cfg.profiling.collector.num_rollouts,
+            warmup_rollouts=cfg.profiling.collector.warmup_rollouts,
+            save_path=cfg.profiling.collector.trace_file,
+            activities=["cpu", "cuda"] if cfg.profiling.profile_cuda else ["cpu"],
+            record_shapes=cfg.profiling.record_shapes,
+            profile_memory=cfg.profiling.profile_memory,
+            with_stack=cfg.profiling.with_stack,
+            with_flops=cfg.profiling.with_flops,
+        )
+
+    # Training config
+    total_optim_steps = cfg.optimization.total_optim_steps
+    log_every = cfg.optimization.log_every
+    grad_clip = cfg.optimization.grad_clip
+    eval_every = cfg.logger.eval_every
+    eval_rollout_steps = cfg.logger.eval_rollout_steps
+
+    # Override total_optim_steps if profiling is enabled
+    if profiling_enabled:
+        total_optim_steps = cfg.profiling.total_optim_steps
+
+    # Training loop - progress bar tracks optimization steps
+    pbar = tqdm.tqdm(total=total_optim_steps, desc="Optim steps")
+
+    # Make optimizer (fused=True for faster GPU execution)
+    use_fused = device.type == "cuda"
+    world_model_opt = torch.optim.Adam(
+        world_model.parameters(), lr=cfg.optimization.world_model_lr, fused=use_fused
+    )
+    actor_opt = torch.optim.Adam(
+        actor_model.parameters(), lr=cfg.optimization.actor_lr, fused=use_fused
+    )
+    value_opt = torch.optim.Adam(
+        value_model.parameters(), lr=cfg.optimization.value_lr, fused=use_fused
+    )
+
+    # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html
+    # autocast can be: false, true (=bfloat16), float16, bfloat16
+    autocast_cfg = cfg.optimization.autocast
+    if autocast_cfg in (False, "false", "False"):
+        autocast_dtype = None
+    elif autocast_cfg in (True, "true", "True", "bfloat16"):
+        autocast_dtype = torch.bfloat16
+    elif autocast_cfg == "float16":
+        autocast_dtype = torch.float16
+    else:
+        raise ValueError(
+            f"Invalid autocast value: {autocast_cfg}. Use false, true, float16, or bfloat16."
+        )
+
+    if autocast_dtype is not None:
+        scaler1 = GradScaler()
+        scaler2 = GradScaler()
+        scaler3 = GradScaler()
+
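For reference, each per-loss update in the training loop further down follows the standard torch.amp recipe: forward under autocast, scale the loss before backward, unscale before gradient clipping, then step and update the scaler. A minimal, self-contained sketch of that pattern (the model, optimizer, and data here are illustrative placeholders, not part of dreamer.py):

# Illustrative AMP sketch; generic model/optimizer, float16 autocast on CUDA.
import torch
from torch.amp import GradScaler
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(8, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
x, y = torch.randn(32, 8, device="cuda"), torch.randn(32, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)  # forward under autocast
scaler.scale(loss).backward()   # scaled backward to avoid fp16 gradient underflow
scaler.unscale_(opt)            # unscale so clipping sees true gradient magnitudes
clip_grad_norm_(model.parameters(), 100.0)
scaler.step(opt)                # skips the step if non-finite gradients are found
scaler.update()
opt.zero_grad()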
+    # Enable TensorFloat32 for better performance on Ampere+ GPUs
+    if device.type == "cuda":
+        torch.set_float32_matmul_precision("high")
+
+    compile_cfg = cfg.optimization.compile
+    compile_enabled = compile_cfg.enabled
+    compile_losses = set(compile_cfg.losses)
+    if compile_enabled:
+        torch._dynamo.config.capture_scalar_outputs = True
+
+        compile_warmup = 3
+        torchrl_logger.info(f"Compiling loss modules with warmup={compile_warmup}")
+        backend = compile_cfg.backend
+        mode = compile_cfg.mode
+
+        # Note: We do NOT compile rssm_prior/rssm_posterior here because they are
+        # shared with the policy used in the collector. Compiling them would cause
+        # issues with the MultiCollector workers.
+        #
+        # Instead, we compile the loss modules themselves which wraps the forward pass.
+        # fullgraph=False allows graph breaks which can help with inductor issues.
+        # warmup=compile_warmup runs eagerly for first `compile_warmup` calls before compiling.
+        if "world_model" in compile_losses:
+            world_model_loss = compile_with_warmup(
+                world_model_loss,
+                backend=backend,
+                mode=mode,
+                fullgraph=False,
+                warmup=compile_warmup,
+            )
+        if "actor" in compile_losses:
+            actor_loss = compile_with_warmup(
+                actor_loss, backend=backend, mode=mode, warmup=compile_warmup
+            )
+        if "value" in compile_losses:
+            value_loss = compile_with_warmup(
+                value_loss, backend=backend, mode=mode, warmup=compile_warmup
+            )
+    else:
+        compile_warmup = 0
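compile_with_warmup comes from torchrl._utils; as the comments above note, it runs the wrapped module eagerly for the first warmup calls and only then hands it to torch.compile. A rough conceptual sketch of that behavior for a plain callable (an assumption-laden illustration, not TorchRL's actual implementation):

# Conceptual warmup-then-compile wrapper (illustrative only).
import torch

def compile_after_warmup(fn, warmup=3, **compile_kwargs):
    state = {"calls": 0, "compiled": None}

    def wrapper(*args, **kwargs):
        if state["compiled"] is not None:
            return state["compiled"](*args, **kwargs)
        state["calls"] += 1
        out = fn(*args, **kwargs)  # eager execution during warmup
        if state["calls"] >= warmup:
            state["compiled"] = torch.compile(fn, **compile_kwargs)
        return out

    return wrapper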
+
+    # Throughput tracking
+    t_log_start = time.time()
+
+    # Profiling setup (encapsulated in helper class)
+    profiler = DreamerProfiler(cfg, device, pbar, compile_warmup=compile_warmup)
+
+    # Start async collection - collector fills the buffer in background
+    torchrl_logger.info("Starting async collection...")
+    torchrl_logger.debug(f"Collector type: {type(collector).__name__}")
+    torchrl_logger.debug(f"Number of collector workers: {cfg.collector.num_collectors}")
+    collector.start()
+    torchrl_logger.debug("collector.start() completed")
+
+    # Wait for enough samples to start training
+    # The collector handles init_random_frames internally, but we also wait here
+    # to ensure the buffer has enough data before we start sampling.
+    # Use init_random_frames_override when collector profiling is enabled
+    if profiling_enabled and cfg.profiling.collector.enabled:
+        min_frames_to_start = cfg.profiling.collector.init_random_frames_override
+        torchrl_logger.info(
+            f"Collector profiling: overriding init_random_frames to {min_frames_to_start}"
+        )
+    else:
+        min_frames_to_start = cfg.collector.init_random_frames
+
+    # Always need at least batch_size frames to sample a batch
+    # (bug fix: init_random_frames_override=0 would hang on empty buffer)
+    min_frames_to_start = max(min_frames_to_start, batch_size)
+    torchrl_logger.info(
+        f"Waiting for {min_frames_to_start} initial frames before training..."
+    )
+    while replay_buffer.write_count < min_frames_to_start:
+        time.sleep(0.1)
+
+    torchrl_logger.info(
+        f"Collected {replay_buffer.write_count} frames (random frames phase complete: {min_frames_to_start} frames). "
+        f"Starting training..."
+    )
+    torchrl_logger.info(
+        "NOTE: From now on, collectors will use the policy instead of random actions. "
+        "Policy outputs keys like 'encoded_latents', 'loc', 'scale' that weren't present during random collection."
+    )
+
+    # Track frames for FPS calculation over logging interval
+    frames_at_log_start = replay_buffer.write_count
+
+    # Main training loop - iterate over optimization steps
+    for optim_step in range(total_optim_steps):
+        # Update progress bar every step
+        pbar.update(1)
+
+        # Debug logging every 100 steps
+        if optim_step % 100 == 0:
+            cuda_mem_allocated = torch.cuda.memory_allocated(device) / (1024**3)
+            cuda_mem_reserved = torch.cuda.memory_reserved(device) / (1024**3)
+            torchrl_logger.debug(
+                f"optim_step={optim_step}: "
+                f"buffer_count={replay_buffer.write_count}, "
+                f"cuda_allocated={cuda_mem_allocated:.2f}GB, "
+                f"cuda_reserved={cuda_mem_reserved:.2f}GB"
+            )
+
+        # sample from replay buffer
+        with timeit("train/sample"), record_function("## train/sample ##"):
+            sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length)
+            if profiling_enabled:
+                torch.cuda.synchronize()
+
+        # update world model
+        with timeit("train/world_model-forward"), record_function(
+            "## world_model/forward ##"
+        ):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type,
+                dtype=autocast_dtype,
+            ) if autocast_dtype else contextlib.nullcontext():
+                assert (
+                    sampled_tensordict.device.type == "cuda"
+                ), "sampled_tensordict should be on CUDA"
+                model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict)
+                loss_world_model = (
+                    model_loss_td["loss_model_kl"]
+                    + model_loss_td["loss_model_reco"]
+                    + model_loss_td["loss_model_reward"]
+                )
+
+        with timeit("train/world_model-backward"), record_function(
+            "## world_model/backward ##"
+        ):
+            world_model_opt.zero_grad()
+            if autocast_dtype:
+                scaler1.scale(loss_world_model).backward()
+                scaler1.unscale_(world_model_opt)
+            else:
+                loss_world_model.backward()
+            torchrl_logger.debug("world_model_loss backward OK")
+            world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler1.step(world_model_opt)
+                scaler1.update()
+            else:
+                world_model_opt.step()
+
+        # update actor network
+        with timeit("train/actor-forward"), record_function("## actor/forward ##"):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type, dtype=autocast_dtype
+            ) if autocast_dtype else contextlib.nullcontext():
+                actor_loss_td, sampled_tensordict = actor_loss(
+                    sampled_tensordict.reshape(-1)
+                )
+
+        with timeit("train/actor-backward"), record_function("## actor/backward ##"):
+            actor_opt.zero_grad()
+            if autocast_dtype:
+                scaler2.scale(actor_loss_td["loss_actor"]).backward()
+                scaler2.unscale_(actor_opt)
+            else:
+                actor_loss_td["loss_actor"].backward()
+            torchrl_logger.debug("actor_loss backward OK")
+            actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler2.step(actor_opt)
+                scaler2.update()
+            else:
+                actor_opt.step()
+
+        # update value network
+        with timeit("train/value-forward"), record_function("## value/forward ##"):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type, dtype=autocast_dtype
+            ) if autocast_dtype else contextlib.nullcontext():
+                value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
+
+        with timeit("train/value-backward"), record_function("## value/backward ##"):
+            value_opt.zero_grad()
+            if autocast_dtype:
+                scaler3.scale(value_loss_td["loss_value"]).backward()
+                scaler3.unscale_(value_opt)
+            else:
+                value_loss_td["loss_value"].backward()
+            torchrl_logger.debug("value_loss backward OK")
+            critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler3.step(value_opt)
+                scaler3.update()
+            else:
+                value_opt.step()
+
+        # Step profiler (returns True if profiling complete)
+        if profiler.step():
+            break
+
+        # Check if profiling is complete and we should exit
+        if profiler.should_exit():
+            torchrl_logger.info("Profiling complete. Exiting training loop.")
+            break
+
+        # Log metrics periodically (every log_every optim steps)
+        if (optim_step + 1) % log_every == 0:
+            # Track collected frames from buffer write count
+            collected_frames = replay_buffer.write_count
+            frames_collected_this_interval = collected_frames - frames_at_log_start
+
+            # Compute throughput metrics
+            t_log_end = time.time()
+            log_interval_time = t_log_end - t_log_start
+
+            # SPS: Samples (batch elements) processed per second
+            total_samples = log_every * batch_size
+            sps = total_samples / log_interval_time if log_interval_time > 0 else 0
+
+            # UPS: Updates (gradient steps) per second
+            # 3 updates per optim step (world_model, actor, value)
+            total_updates = log_every * 3
+            ups = total_updates / log_interval_time if log_interval_time > 0 else 0
+
+            # FPS: Frames collected per second (measured from buffer over logging interval)
+            fps = (
+                frames_collected_this_interval / log_interval_time
+                if log_interval_time > 0
+                else 0
+            )
+
+            # OPS: Optim steps per second
+            ops = log_every / log_interval_time if log_interval_time > 0 else 0
+
+            # OPF: Optim steps per frame (ratio of training to collection)
+            opf = (optim_step + 1) / collected_frames if collected_frames > 0 else 0
+
+            # Update progress bar with throughput metrics
+            pbar.set_postfix(
+                fps=f"{fps:.1f}",
+                ops=f"{ops:.1f}",
+                opf=f"{opf:.2f}",
+                frames=collected_frames,
+            )
+
+            # Get reward stats from sampled data (since we don't iterate over collector directly)
+            sampled_reward = sampled_tensordict.get(("next", "reward"))
+            reward_mean = sampled_reward.mean().item()
+            reward_std = sampled_reward.std().item()
+
+            metrics = {
+                "loss_model_kl": model_loss_td["loss_model_kl"].item(),
+                "loss_model_reco": model_loss_td["loss_model_reco"].item(),
+                "loss_model_reward": model_loss_td["loss_model_reward"].item(),
+                "loss_actor": actor_loss_td["loss_actor"].item(),
+                "loss_value": value_loss_td["loss_value"].item(),
+                "world_model_grad": world_model_grad,
+                "actor_model_grad": actor_model_grad,
+                "critic_model_grad": critic_model_grad,
+                # Reward stats from sampled batch
+                "train/reward_mean": reward_mean,
+                "train/reward_std": reward_std,
+                # Throughput metrics
+                "throughput/fps": fps,  # Frames per second (collection)
+                "throughput/ops": ops,  # Optim steps per second
+                "throughput/opf": opf,  # Optim steps per frame
+                "throughput/sps": sps,  # Samples per second (training)
+                "throughput/ups": ups,  # Updates per second (gradient steps)
+                "throughput/log_interval_time": log_interval_time,
+                # Collection tracking (not a target, just for monitoring)
+                "collected_frames": collected_frames,
+                # Policy version tracking
+                "policy_version": policy_version.version,
+                # Detailed timing from timeit (some metrics may be empty when compiling)
+                **timeit.todict(prefix="time"),
+            }
+
+            if logger is not None:
+                log_metrics(logger, metrics, collected_frames)
+
+            # Reset timer and frame counter for next logging interval
+            t_log_start = time.time()
+            frames_at_log_start = collected_frames
+
+            # Update policy weights in collector (for async collection)
+            with timeit("train/weight_update") as weight_update_timer:
+                torchrl_logger.debug(
+                    f"optim_step={optim_step}: Starting weight update..."
+                )
+                policy[1].step(frames_collected_this_interval)
+                collector.update_policy_weights_()
+                # Increment policy version after weight update
+                collector.increment_version()
+                torchrl_logger.debug(
+                    f"optim_step={optim_step}: Weight update completed in "
+                    f"{weight_update_timer.elapsed():.3f}s, "
+                    f"policy_version={policy_version.version}"
+                )
+
+        # Evaluation (every eval_every optimization steps)
+        if (optim_step + 1) % eval_every == 0:
+            # Real env
+            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+                eval_rollout = test_env.rollout(
+                    eval_rollout_steps,
+                    policy,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                test_env.apply(
+                    functools.partial(dump_video, step=replay_buffer.write_count)
+                )
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                eval_metrics = {"eval/reward": eval_reward}
+                if logger is not None:
+                    log_metrics(logger, eval_metrics, replay_buffer.write_count)
+            # Simulated env
+            if model_based_env_eval is not None:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_rollout = model_based_env_eval.rollout(
+                        eval_rollout_steps,
+                        policy,
+                        auto_cast_to_device=True,
+                        break_when_any_done=True,
+                        auto_reset=False,
+                        tensordict=eval_rollout[..., 0]
+                        .exclude("next", "action")
+                        .to(device),
+                    )
+                    model_based_env_eval.apply(
+                        functools.partial(dump_video, step=replay_buffer.write_count)
+                    )
+                    eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                    eval_metrics = {"eval/simulated_reward": eval_reward}
+                    if logger is not None:
+                        log_metrics(logger, eval_metrics, replay_buffer.write_count)
+
+    if not test_env.is_closed:
+        test_env.close()
+    # Shutdown async collector (use async_shutdown since we used start())
+    collector.async_shutdown()
+
+    del test_env
+    del collector
+
+
+if __name__ == "__main__":
+    main()