torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
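
The three hunks below reproduce the complete contents of sota-implementations/impala/impala_single_node.py, sota-implementations/impala/utils.py, and sota-implementations/iql/discrete_iql.py (items 64, 65, and 66 in the list above); the remaining files appear by name and line count only. As a quick post-install sanity check that the wheel's top-level package and the modules those scripts rely on resolve, a minimal sketch (both module imports appear verbatim in impala_single_node.py below; the expected version string is taken from the wheel name):

import torchrl
from torchrl.collectors import MultiaSyncDataCollector  # noqa: F401
from torchrl.objectives import A2CLoss  # noqa: F401

print(torchrl.__version__)  # expected: 0.11.0
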
sota-implementations/impala/impala_single_node.py
@@ -0,0 +1,261 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This script reproduces the IMPALA algorithm
+ results from Espeholt et al. 2018 on Atari environments.
+ """
+ from __future__ import annotations
+
+ import hydra
+ from torchrl._utils import logger as torchrl_logger
+
+
+ @hydra.main(config_path="", config_name="config_single_node", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+
+     import time
+
+     import torch.optim
+     import tqdm
+
+     from tensordict import TensorDict
+     from torchrl.collectors import MultiaSyncDataCollector
+     from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+     from torchrl.envs import ExplorationType, set_exploration_type
+     from torchrl.objectives import A2CLoss
+     from torchrl.objectives.value import VTrace
+     from torchrl.record.loggers import generate_exp_name, get_logger
+     from utils import eval_model, make_env, make_ppo_models
+
+     device = cfg.device
+     if not device:
+         device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+     else:
+         device = torch.device(device)
+
+     # Correct for frame_skip
+     frame_skip = 4
+     total_frames = cfg.collector.total_frames // frame_skip
+     frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+     test_interval = cfg.logger.test_interval // frame_skip
+
+     # Extract other config parameters
+     batch_size = cfg.loss.batch_size  # Number of rollouts per batch
+     num_workers = (
+         cfg.collector.num_workers
+     )  # Number of parallel workers collecting rollouts
+     lr = cfg.optim.lr
+     anneal_lr = cfg.optim.anneal_lr
+     sgd_updates = cfg.loss.sgd_updates
+     max_grad_norm = cfg.optim.max_grad_norm
+     num_test_episodes = cfg.logger.num_test_episodes
+     total_network_updates = (
+         total_frames // (frames_per_batch * batch_size)
+     ) * cfg.loss.sgd_updates
+
+     # Create models (check utils.py)
+     actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
+
+     # Create collector
+     collector = MultiaSyncDataCollector(
+         create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+         * num_workers,
+         policy=actor,
+         frames_per_batch=frames_per_batch,
+         total_frames=total_frames,
+         device=device,
+         storing_device=device,
+         max_frames_per_traj=-1,
+         update_at_each_batch=True,
+     )
+
+     # Create data buffer
+     sampler = SamplerWithoutReplacement()
+     data_buffer = TensorDictReplayBuffer(
+         storage=LazyMemmapStorage(frames_per_batch * batch_size),
+         sampler=sampler,
+         batch_size=frames_per_batch * batch_size,
+     )
+
+     # Create loss and adv modules
+     adv_module = VTrace(
+         gamma=cfg.loss.gamma,
+         value_network=critic,
+         actor_network=actor,
+         average_adv=False,
+     )
+     loss_module = A2CLoss(
+         actor_network=actor,
+         critic_network=critic,
+         loss_critic_type=cfg.loss.loss_critic_type,
+         entropy_coeff=cfg.loss.entropy_coeff,
+         critic_coeff=cfg.loss.critic_coeff,
+     )
+     loss_module.set_keys(done="eol", terminated="eol")
+
+     # Create optimizer
+     optim = torch.optim.RMSprop(
+         loss_module.parameters(),
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.eps,
+         alpha=cfg.optim.alpha,
+     )
+
+     # Create logger
+     logger = None
+     if cfg.logger.backend:
+         exp_name = generate_exp_name(
+             "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+         )
+         logger = get_logger(
+             cfg.logger.backend,
+             logger_name="impala",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Create test environment
+     test_env = make_env(
+         cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+     )
+     test_env.eval()
+
+     # Main loop
+     collected_frames = 0
+     num_network_updates = 0
+     pbar = tqdm.tqdm(total=total_frames)
+     accumulator = []
+     start_time = sampling_start = time.time()
+     for i, data in enumerate(collector):
+
+         metrics_to_log = {}
+         sampling_time = time.time() - sampling_start
+         frames_in_batch = data.numel()
+         collected_frames += frames_in_batch * frame_skip
+         pbar.update(data.numel())
+
+         # Get training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+         if len(episode_rewards) > 0:
+             episode_length = data["next", "step_count"][data["next", "terminated"]]
+             metrics_to_log.update(
+                 {
+                     "train/reward": episode_rewards.mean().item(),
+                     "train/episode_length": episode_length.sum().item()
+                     / len(episode_length),
+                 }
+             )
+
+         if len(accumulator) < batch_size:
+             accumulator.append(data)
+             if logger:
+                 for key, value in metrics_to_log.items():
+                     logger.log_scalar(key, value, collected_frames)
+             continue
+
+         losses = TensorDict(batch_size=[sgd_updates])
+         training_start = time.time()
+         for j in range(sgd_updates):
+
+             # Create a single batch of trajectories
+             stacked_data = torch.stack(accumulator, dim=0).contiguous()
+             stacked_data = stacked_data.to(device, non_blocking=True)
+
+             # Compute advantage
+             with torch.no_grad():
+                 stacked_data = adv_module(stacked_data)
+
+             # Add to replay buffer
+             for stacked_d in stacked_data:
+                 stacked_data_reshape = stacked_d.reshape(-1)
+                 data_buffer.extend(stacked_data_reshape)
+
+             for batch in data_buffer:
+
+                 # Linearly decrease the learning rate and clip epsilon
+                 alpha = 1.0
+                 if anneal_lr:
+                     alpha = 1 - (num_network_updates / total_network_updates)
+                     for group in optim.param_groups:
+                         group["lr"] = lr * alpha
+                 num_network_updates += 1
+
+                 # Get a data batch
+                 batch = batch.to(device, non_blocking=True)
+
+                 # Forward pass loss
+                 loss = loss_module(batch)
+                 losses[j] = loss.select(
+                     "loss_critic", "loss_entropy", "loss_objective"
+                 ).detach()
+                 loss_sum = (
+                     loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+                 )
+
+                 # Backward pass
+                 loss_sum.backward()
+                 torch.nn.utils.clip_grad_norm_(
+                     list(loss_module.parameters()), max_norm=max_grad_norm
+                 )
+
+                 # Update the networks
+                 optim.step()
+                 optim.zero_grad()
+
+         # Get training losses and times
+         training_time = time.time() - training_start
+         losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+         for key, value in losses.items():
+             metrics_to_log.update({f"train/{key}": value.item()})
+         metrics_to_log.update(
+             {
+                 "train/lr": alpha * lr,
+                 "train/sampling_time": sampling_time,
+                 "train/training_time": training_time,
+             }
+         )
+
+         # Get test rewards
+         with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+                 i * frames_in_batch * frame_skip
+             ) // test_interval:
+                 actor.eval()
+                 eval_start = time.time()
+                 test_reward = eval_model(
+                     actor, test_env, num_episodes=num_test_episodes
+                 )
+                 eval_time = time.time() - eval_start
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_reward,
+                         "eval/time": eval_time,
+                     }
+                 )
+                 actor.train()
+
+         if logger:
+             for key, value in metrics_to_log.items():
+                 logger.log_scalar(key, value, collected_frames)
+
+         collector.update_policy_weights_()
+         sampling_start = time.time()
+         accumulator = []
+
+     collector.shutdown()
+     end_time = time.time()
+     execution_time = end_time - start_time
+     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
+
+
+ if __name__ == "__main__":
+     main()
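
The script above resolves its hyperparameters from a Hydra config named config_single_node, which is not part of this diff. A minimal sketch of the fields it dereferences, written with OmegaConf for illustration rather than the YAML file shipped alongside the script; the field names mirror the cfg.* accesses above, while the values are placeholders and not the repository defaults:

from omegaconf import OmegaConf

# Field names mirror the cfg.* accesses in impala_single_node.py above; the
# values are illustrative placeholders only.
cfg = OmegaConf.create(
    {
        "device": "",  # empty -> the script picks cuda:0 if available, else cpu
        "env": {"env_name": "PongNoFrameskip-v4", "backend": "gymnasium"},
        "collector": {
            "frames_per_batch": 80,
            "total_frames": 200_000_000,
            "num_workers": 12,
        },
        "loss": {
            "batch_size": 32,
            "sgd_updates": 1,
            "gamma": 0.99,
            "loss_critic_type": "l2",
            "entropy_coeff": 0.01,
            "critic_coeff": 0.5,
        },
        "optim": {
            "lr": 6e-4,
            "anneal_lr": True,
            "weight_decay": 0.0,
            "eps": 1e-8,
            "alpha": 0.99,
            "max_grad_norm": 40.0,
        },
        "logger": {
            "backend": "",  # empty -> no logger is created
            "exp_name": "impala",
            "project_name": "torchrl-sota",
            "group_name": "",
            "test_interval": 200_000_000,
            "num_test_episodes": 3,
        },
    }
)
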
sota-implementations/impala/utils.py
@@ -0,0 +1,184 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch.nn
+ import torch.optim
+ from tensordict.nn import TensorDictModule
+ from torchrl.envs import (
+     CatFrames,
+     DoubleToFloat,
+     EndOfLifeTransform,
+     ExplorationType,
+     GrayScale,
+     GymEnv,
+     NoopResetEnv,
+     Resize,
+     RewardSum,
+     set_gym_backend,
+     SignTransform,
+     StepCounter,
+     ToTensorImage,
+     TransformedEnv,
+     VecNorm,
+ )
+ from torchrl.modules import (
+     ActorValueOperator,
+     ConvNet,
+     MLP,
+     OneHotCategorical,
+     ProbabilisticActor,
+     ValueOperator,
+ )
+
+
+ # ====================================================================
+ # Environment utils
+ # --------------------------------------------------------------------
+
+
+ def make_env(env_name, device, gym_backend, is_test=False):
+     with set_gym_backend(gym_backend):
+         env = GymEnv(
+             env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
+         )
+     env = TransformedEnv(env)
+     env.append_transform(NoopResetEnv(noops=30, random=True))
+     if not is_test:
+         env.append_transform(EndOfLifeTransform())
+         env.append_transform(SignTransform(in_keys=["reward"]))
+     env.append_transform(ToTensorImage(from_int=False))
+     env.append_transform(GrayScale())
+     env.append_transform(Resize(84, 84))
+     env.append_transform(CatFrames(N=4, dim=-3))
+     env.append_transform(RewardSum())
+     env.append_transform(StepCounter(max_steps=4500))
+     env.append_transform(DoubleToFloat())
+     env.append_transform(VecNorm(in_keys=["pixels"]))
+     return env
+
+
+ # ====================================================================
+ # Model utils
+ # --------------------------------------------------------------------
+
+
+ def make_ppo_modules_pixels(proof_environment):
+
+     # Define input shape
+     input_shape = proof_environment.observation_spec["pixels"].shape
+
+     # Define distribution class and kwargs
+     num_outputs = proof_environment.action_spec_unbatched.space.n
+     distribution_class = OneHotCategorical
+     distribution_kwargs = {}
+
+     # Define input keys
+     in_keys = ["pixels"]
+
+     # Define a shared Module and TensorDictModule (CNN + MLP)
+     common_cnn = ConvNet(
+         activation_class=torch.nn.ReLU,
+         num_cells=[32, 64, 64],
+         kernel_sizes=[8, 4, 3],
+         strides=[4, 2, 1],
+     )
+     common_cnn_output = common_cnn(torch.ones(input_shape))
+     common_mlp = MLP(
+         in_features=common_cnn_output.shape[-1],
+         activation_class=torch.nn.ReLU,
+         activate_last_layer=True,
+         out_features=512,
+         num_cells=[],
+     )
+     common_mlp_output = common_mlp(common_cnn_output)
+
+     # Define shared net as TensorDictModule
+     common_module = TensorDictModule(
+         module=torch.nn.Sequential(common_cnn, common_mlp),
+         in_keys=in_keys,
+         out_keys=["common_features"],
+     )
+
+     # Define one head for the policy
+     policy_net = MLP(
+         in_features=common_mlp_output.shape[-1],
+         out_features=num_outputs,
+         activation_class=torch.nn.ReLU,
+         num_cells=[],
+     )
+     policy_module = TensorDictModule(
+         module=policy_net,
+         in_keys=["common_features"],
+         out_keys=["logits"],
+     )
+
+     # Add probabilistic sampling of the actions
+     policy_module = ProbabilisticActor(
+         policy_module,
+         in_keys=["logits"],
+         spec=proof_environment.full_action_spec_unbatched,
+         distribution_class=distribution_class,
+         distribution_kwargs=distribution_kwargs,
+         return_log_prob=True,
+         default_interaction_type=ExplorationType.RANDOM,
+     )
+
+     # Define another head for the value
+     value_net = MLP(
+         activation_class=torch.nn.ReLU,
+         in_features=common_mlp_output.shape[-1],
+         out_features=1,
+         num_cells=[],
+     )
+     value_module = ValueOperator(
+         value_net,
+         in_keys=["common_features"],
+     )
+
+     return common_module, policy_module, value_module
+
+
+ def make_ppo_models(env_name, gym_backend):
+
+     proof_environment = make_env(env_name, device="cpu", gym_backend=gym_backend)
+     common_module, policy_module, value_module = make_ppo_modules_pixels(
+         proof_environment
+     )
+
+     # Wrap modules in a single ActorCritic operator
+     actor_critic = ActorValueOperator(
+         common_operator=common_module,
+         policy_operator=policy_module,
+         value_operator=value_module,
+     )
+
+     actor = actor_critic.get_policy_operator()
+     critic = actor_critic.get_value_operator()
+
+     del proof_environment
+
+     return actor, critic
+
+
+ # ====================================================================
+ # Evaluation utils
+ # --------------------------------------------------------------------
+
+
+ def eval_model(actor, test_env, num_episodes=3):
+     test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
+     for i in range(num_episodes):
+         td_test = test_env.rollout(
+             policy=actor,
+             auto_reset=True,
+             auto_cast_to_device=True,
+             break_when_any_done=True,
+             max_steps=10_000_000,
+         )
+         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+         test_rewards[i] = reward.sum()
+         del td_test
+     return test_rewards.mean()
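
Beyond their use in impala_single_node.py above, these helpers compose on their own. A minimal sketch that builds an untrained actor/critic pair and evaluates it for a single episode (the environment name and backend are placeholders, and the Atari dependencies for the chosen gym backend are assumed to be installed):

# Assumes this module is importable as `utils`, as in the training script above.
from utils import eval_model, make_env, make_ppo_models

actor, critic = make_ppo_models("PongNoFrameskip-v4", gym_backend="gymnasium")
test_env = make_env(
    "PongNoFrameskip-v4", device="cpu", gym_backend="gymnasium", is_test=True
)
test_env.eval()
print(float(eval_model(actor, test_env, num_episodes=1)))  # mean episode reward
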
sota-implementations/iql/discrete_iql.py
@@ -0,0 +1,230 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """IQL Example.
+
+ This is a self-contained example of an online discrete IQL training script.
+
+ It works across Gym and MuJoCo over a variety of tasks.
+
+ The helper functions are coded in the utils.py associated with this script.
+
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import numpy as np
+ import torch
+ import tqdm
+ from tensordict import TensorDict
+ from tensordict.nn import CudaGraphModule
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.envs import set_gym_backend
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives import group_optimizers
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import (
+     dump_video,
+     log_metrics,
+     make_collector,
+     make_discrete_iql_model,
+     make_discrete_loss,
+     make_environment,
+     make_iql_optimizer,
+     make_replay_buffer,
+ )
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(config_path="", config_name="discrete_iql")
+ def main(cfg: DictConfig):  # noqa: F821
+     set_gym_backend(cfg.env.backend).set()
+
+     # Create logger
+     exp_name = generate_exp_name("Discrete-IQL-online", cfg.logger.exp_name)
+     logger = None
+     if cfg.logger.backend:
+         logger = get_logger(
+             logger_type=cfg.logger.backend,
+             logger_name="iql_logging",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "mode": cfg.logger.mode,
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Set seeds
+     torch.manual_seed(cfg.env.seed)
+     np.random.seed(cfg.env.seed)
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     # Create environments
+     train_env, eval_env = make_environment(
+         cfg,
+         cfg.env.train_num_envs,
+         cfg.env.eval_num_envs,
+         logger=logger,
+     )
+
+     # Create replay buffer
+     replay_buffer = make_replay_buffer(
+         batch_size=cfg.optim.batch_size,
+         prb=cfg.replay_buffer.prb,
+         buffer_size=cfg.replay_buffer.size,
+         device="cpu",
+     )
+
+     # Create model
+     model = make_discrete_iql_model(cfg, train_env, eval_env, device)
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
+     # Create collector
+     collector = make_collector(
+         cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+     )
+
+     # Create loss
+     loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)
+
+     # Create optimizer
+     optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
+         cfg.optim, loss_module
+     )
+     optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value)
+     del optimizer_actor, optimizer_critic, optimizer_value
+
+     def update(sampled_tensordict):
+         optimizer.zero_grad(set_to_none=True)
+         # compute losses
+         actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
+         value_loss, _ = loss_module.value_loss(sampled_tensordict)
+         q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
+         (actor_loss + value_loss + q_loss).backward()
+         optimizer.step()
+
+         # update qnet_target params
+         target_net_updater.step()
+         metadata.update(
+             {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
+         )
+         return TensorDict(metadata).detach()
+
+     if cfg.compile.compile:
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Main loop
+     collected_frames = 0
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     init_random_frames = cfg.collector.init_random_frames
+     num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+     prb = cfg.replay_buffer.prb
+     eval_iter = cfg.logger.eval_iter
+     frames_per_batch = cfg.collector.frames_per_batch
+     eval_rollout_steps = cfg.collector.max_frames_per_traj
+
+     collector_iter = iter(collector)
+     total_iter = len(collector)
+     for _ in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+
+         with timeit("collection"):
+             tensordict = next(collector_iter)
+         current_frames = tensordict.numel()
+         pbar.update(current_frames)
+
+         # update weights of the inference policy
+         collector.update_policy_weights_()
+
+         with timeit("buffer - extend"):
+             tensordict = tensordict.reshape(-1)
+
+             # add to replay buffer
+             replay_buffer.extend(tensordict)
+         collected_frames += current_frames
+
+         # optimization steps
+         with timeit("training"):
+             if collected_frames >= init_random_frames:
+                 for _ in range(num_updates):
+                     # sample from replay buffer
+                     with timeit("buffer - sample"):
+                         sampled_tensordict = replay_buffer.sample().to(device)
+
+                     with timeit("training - update"):
+                         torch.compiler.cudagraph_mark_step_begin()
+                         metadata = update(sampled_tensordict)
+                     # update priority
+                     if prb:
+                         sampled_tensordict.set(
+                             loss_module.tensor_keys.priority,
+                             metadata.pop("td_error").detach().max(0).values,
+                         )
+                         replay_buffer.update_priority(sampled_tensordict)
+
+         episode_rewards = tensordict["next", "episode_reward"][
+             tensordict["next", "done"]
+         ]
+
+         metrics_to_log = {}
+         # Evaluation
+         if abs(collected_frames % eval_iter) < frames_per_batch:
+             with set_exploration_type(
+                 ExplorationType.DETERMINISTIC
+             ), torch.no_grad(), timeit("eval"):
+                 eval_rollout = eval_env.rollout(
+                     eval_rollout_steps,
+                     model[0],
+                     auto_cast_to_device=True,
+                     break_when_any_done=True,
+                 )
+                 eval_env.apply(dump_video)
+                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                 metrics_to_log["eval/reward"] = eval_reward
+
+         # Logging
+         if len(episode_rewards) > 0:
+             episode_length = tensordict["next", "step_count"][
+                 tensordict["next", "done"]
+             ]
+             metrics_to_log["train/reward"] = episode_rewards.mean().item()
+             metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                 episode_length
+             )
+         if collected_frames >= init_random_frames:
+             metrics_to_log["train/q_loss"] = metadata["q_loss"]
+             metrics_to_log["train/actor_loss"] = metadata["actor_loss"]
+             metrics_to_log["train/value_loss"] = metadata["value_loss"]
+         if logger is not None:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             log_metrics(logger, metrics_to_log, collected_frames)
+
+     collector.shutdown()
+
+
+ if __name__ == "__main__":
+     main()
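
The update function above is first compiled with torch.compile and then, optionally, captured by CudaGraphModule before the training loop starts. A minimal sketch of that wrapping pattern on a toy update step; the network, optimizer, and tensordict keys are placeholders, and (as the warning in the script notes) CudaGraphModule is experimental:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

net = torch.nn.Linear(4, 1)
optim = torch.optim.Adam(net.parameters())


def update(td):
    # Mirrors the structure of the real update(): zero grads, compute a loss,
    # backward, step, and return detached metrics as a TensorDict.
    optim.zero_grad(set_to_none=True)
    loss = (net(td["obs"]) - td["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    return TensorDict({"loss": loss.detach()})


update = torch.compile(update, mode="reduce-overhead")
if torch.cuda.is_available():
    # Capture only on CUDA; call torch.compiler.cudagraph_mark_step_begin()
    # before each invocation, as the training loop above does.
    update = CudaGraphModule(update, warmup=50)
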