torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/ppo/ppo_atari.py
@@ -0,0 +1,305 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This script reproduces the Proximal Policy Optimization (PPO) Algorithm
+ results from Schulman et al. 2017 for the Atari Environments.
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ from torchrl._utils import compile_with_warmup, get_available_device
+
+
+ @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+
+     import torch.optim
+     import tqdm
+
+     from tensordict import TensorDict
+     from tensordict.nn import CudaGraphModule
+
+     from torchrl._utils import timeit
+     from torchrl.collectors import SyncDataCollector
+     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+     from torchrl.envs import ExplorationType, set_exploration_type
+     from torchrl.objectives import ClipPPOLoss
+     from torchrl.objectives.value.advantages import GAE
+     from torchrl.record import VideoRecorder
+     from torchrl.record.loggers import generate_exp_name, get_logger
+     from utils_atari import eval_model, make_parallel_env, make_ppo_models
+
+     torch.set_float32_matmul_precision("high")
+
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     # Correct for frame_skip
+     frame_skip = 4
+     total_frames = cfg.collector.total_frames // frame_skip
+     frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+     mini_batch_size = cfg.loss.mini_batch_size // frame_skip
+     test_interval = cfg.logger.test_interval // frame_skip
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
+     # Create models (check utils_atari.py)
+     actor, critic = make_ppo_models(
+         cfg.env.env_name, device=device, gym_backend=cfg.env.backend
+     )
+
+     # Create collector
+     collector = SyncDataCollector(
+         create_env_fn=make_parallel_env(
+             cfg.env.env_name,
+             num_envs=cfg.env.num_envs,
+             device=device,
+             gym_backend=cfg.env.backend,
+         ),
+         policy=actor,
+         frames_per_batch=frames_per_batch,
+         total_frames=total_frames,
+         device=device,
+         max_frames_per_traj=-1,
+         compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+
+     # Create data buffer
+     sampler = SamplerWithoutReplacement()
+     data_buffer = TensorDictReplayBuffer(
+         storage=LazyTensorStorage(
+             frames_per_batch, compilable=cfg.compile.compile, device=device
+         ),
+         sampler=sampler,
+         batch_size=mini_batch_size,
+         compilable=cfg.compile.compile,
+     )
+
+     # Create loss and adv modules
+     adv_module = GAE(
+         gamma=cfg.loss.gamma,
+         lmbda=cfg.loss.gae_lambda,
+         value_network=critic,
+         average_gae=False,
+         device=device,
+         vectorized=not cfg.compile.compile,
+     )
+     loss_module = ClipPPOLoss(
+         actor_network=actor,
+         critic_network=critic,
+         clip_epsilon=cfg.loss.clip_epsilon,
+         loss_critic_type=cfg.loss.loss_critic_type,
+         entropy_coeff=cfg.loss.entropy_coeff,
+         critic_coeff=cfg.loss.critic_coeff,
+         normalize_advantage=True,
+     )
+
+     # use end-of-life as done key
+     adv_module.set_keys(done="end-of-life", terminated="end-of-life")
+     loss_module.set_keys(done="end-of-life", terminated="end-of-life")
+
+     # Create optimizer
+     optim = torch.optim.Adam(
+         loss_module.parameters(),
+         lr=cfg.optim.lr,
+         weight_decay=cfg.optim.weight_decay,
+         eps=cfg.optim.eps,
+     )
+
+     # Create logger
+     logger = None
+     if cfg.logger.backend:
+         exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
+         logger = get_logger(
+             cfg.logger.backend,
+             logger_name="ppo",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+         logger_video = cfg.logger.video
+     else:
+         logger_video = False
+
+     # Create test environment
+     test_env = make_parallel_env(
+         cfg.env.env_name, 1, device, is_test=True, gym_backend=cfg.env.backend
+     )
+     if logger_video:
+         test_env = test_env.append_transform(
+             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])
+         )
+     test_env.eval()
+
+     # Main loop
+     collected_frames = 0
+     num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
+     pbar = tqdm.tqdm(total=total_frames)
+     num_mini_batches = frames_per_batch // mini_batch_size
+     total_network_updates = (
+         (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
+     )
+
+     def update(batch, num_network_updates):
+         optim.zero_grad(set_to_none=True)
+
+         # Linearly decrease the learning rate and clip epsilon
+         alpha = torch.ones((), device=device)
+         if cfg_optim_anneal_lr:
+             alpha = 1 - (num_network_updates / total_network_updates)
+             for group in optim.param_groups:
+                 group["lr"] = cfg_optim_lr * alpha
+         if cfg_loss_anneal_clip_eps:
+             loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+         num_network_updates = num_network_updates + 1
+         # Get a data batch
+         batch = batch.to(device, non_blocking=True)
+
+         # Forward pass PPO loss
+         loss = loss_module(batch)
+         loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+         # Backward pass
+         loss_sum.backward()
+         torch.nn.utils.clip_grad_norm_(
+             loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
+         )
+
+         # Update the networks
+         optim.step()
+         return loss.detach().set("alpha", alpha), num_network_updates
+
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+         adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+         adv_module = CudaGraphModule(adv_module)
+
+     # extract cfg variables
+     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
+     cfg_optim_anneal_lr = cfg.optim.anneal_lr
+     cfg_optim_lr = cfg.optim.lr
+     cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
+     cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
+     cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+     cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
+     cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
+     losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
+
+     collector_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+
+         with timeit("collecting"):
+             data = next(collector_iter)
+
+         metrics_to_log = {}
+         frames_in_batch = data.numel()
+         collected_frames += frames_in_batch * frame_skip
+         pbar.update(frames_in_batch)
+
+         # Get training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+         if len(episode_rewards) > 0:
+             episode_length = data["next", "step_count"][data["next", "terminated"]]
+             metrics_to_log.update(
+                 {
+                     "train/reward": episode_rewards.mean().item(),
+                     "train/episode_length": episode_length.sum().item()
+                     / len(episode_length),
+                 }
+             )
+
+         with timeit("training"):
+             for j in range(cfg_loss_ppo_epochs):
+
+                 # Compute GAE
+                 with torch.no_grad(), timeit("adv"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     data = adv_module(data)
+                     if compile_mode:
+                         data = data.clone()
+                 with timeit("rb - extend"):
+                     # Update the data buffer
+                     data_reshape = data.reshape(-1)
+                     data_buffer.extend(data_reshape)
+
+                 for k, batch in enumerate(data_buffer):
+                     with timeit("update"):
+                         torch.compiler.cudagraph_mark_step_begin()
+                         loss, num_network_updates = update(
+                             batch, num_network_updates=num_network_updates
+                         )
+                         loss = loss.clone()
+                         num_network_updates = num_network_updates.clone()
+                     losses[j, k] = loss.select(
+                         "loss_critic", "loss_entropy", "loss_objective"
+                     )
+
+         # Get training losses and times
+         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
+         for key, value in losses_mean.items():
+             metrics_to_log.update({f"train/{key}": value.item()})
+         metrics_to_log.update(
+             {
+                 "train/lr": loss["alpha"] * cfg_optim_lr,
+                 "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon,
+             }
+         )
+
+         # Get test rewards
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
+             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+                 i * frames_in_batch * frame_skip
+             ) // test_interval:
+                 actor.eval()
+                 test_rewards = eval_model(
+                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
+                 )
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_rewards.mean(),
+                     }
+                 )
+                 actor.train()
+         if logger:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             for key, value in metrics_to_log.items():
+                 logger.log_scalar(key, value, collected_frames)
+
+         collector.update_policy_weights_()
+
+     collector.shutdown()
+     if not test_env.is_closed:
+         test_env.close()
+
+
+ if __name__ == "__main__":
+     main()
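Note on the update() step above: both PPO scripts (Atari and MuJoCo) anneal the learning rate and the PPO clip epsilon linearly over the total number of network updates, scaling each by alpha = 1 - num_network_updates / total_network_updates. A minimal standalone sketch of that schedule follows; it is not part of the package, and base_lr, base_clip_eps and total_network_updates are made-up stand-ins for the Hydra config values (cfg.optim.lr, cfg.loss.clip_epsilon and the product computed in the script).

    import torch
    from torch import nn

    model = nn.Linear(4, 2)  # stand-in for the actor/critic networks
    base_lr, base_clip_eps = 3e-4, 0.2  # hypothetical config values
    total_network_updates = 1000
    optim = torch.optim.Adam(model.parameters(), lr=base_lr)

    def anneal(num_network_updates):
        # Mirrors the logic inside update(): scale lr and clip epsilon by alpha.
        alpha = 1 - num_network_updates / total_network_updates
        for group in optim.param_groups:
            group["lr"] = base_lr * alpha
        clip_eps = base_clip_eps * alpha  # the script copies this into loss_module.clip_epsilon
        return alpha, clip_eps

    alpha, clip_eps = anneal(250)
    print(alpha, optim.param_groups[0]["lr"], clip_eps)  # approx. 0.75, 2.25e-4, 0.15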
sota-implementations/ppo/ppo_mujoco.py
@@ -0,0 +1,293 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This script reproduces the Proximal Policy Optimization (PPO) Algorithm
+ results from Schulman et al. 2017 for the MuJoCo Environments.
+ """
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ from torchrl._utils import compile_with_warmup, get_available_device
+
+
+ @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+
+     import torch.optim
+     import tqdm
+
+     from tensordict import TensorDict
+     from tensordict.nn import CudaGraphModule
+
+     from torchrl._utils import timeit
+     from torchrl.collectors import SyncDataCollector
+     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+     from torchrl.envs import ExplorationType, set_exploration_type
+     from torchrl.objectives import ClipPPOLoss, group_optimizers
+     from torchrl.objectives.value.advantages import GAE
+     from torchrl.record import VideoRecorder
+     from torchrl.record.loggers import generate_exp_name, get_logger
+     from utils_mujoco import eval_model, make_env, make_ppo_models
+
+     torch.set_float32_matmul_precision("high")
+
+     device = (
+         torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+     )
+
+     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
+     total_network_updates = (
+         (cfg.collector.total_frames // cfg.collector.frames_per_batch)
+         * cfg.loss.ppo_epochs
+         * num_mini_batches
+     )
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+
+     # Create models (check utils_mujoco.py)
+     actor, critic = make_ppo_models(cfg.env.env_name, device=device)
+
+     # Create collector
+     collector = SyncDataCollector(
+         create_env_fn=make_env(cfg.env.env_name, device),
+         policy=actor,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         device=device,
+         max_frames_per_traj=-1,
+         compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+
+     # Create data buffer
+     sampler = SamplerWithoutReplacement()
+     data_buffer = TensorDictReplayBuffer(
+         storage=LazyTensorStorage(
+             cfg.collector.frames_per_batch,
+             compilable=cfg.compile.compile,
+             device=device,
+         ),
+         sampler=sampler,
+         batch_size=cfg.loss.mini_batch_size,
+         compilable=cfg.compile.compile,
+     )
+
+     # Create loss and adv modules
+     adv_module = GAE(
+         gamma=cfg.loss.gamma,
+         lmbda=cfg.loss.gae_lambda,
+         value_network=critic,
+         average_gae=False,
+         device=device,
+         vectorized=not cfg.compile.compile,
+     )
+
+     loss_module = ClipPPOLoss(
+         actor_network=actor,
+         critic_network=critic,
+         clip_epsilon=cfg.loss.clip_epsilon,
+         loss_critic_type=cfg.loss.loss_critic_type,
+         entropy_coeff=cfg.loss.entropy_coeff,
+         critic_coeff=cfg.loss.critic_coeff,
+         normalize_advantage=True,
+     )
+
+     # Create optimizers
+     actor_optim = torch.optim.Adam(
+         actor.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5
+     )
+     critic_optim = torch.optim.Adam(
+         critic.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5
+     )
+     optim = group_optimizers(actor_optim, critic_optim)
+     del actor_optim, critic_optim
+
+     # Create logger
+     logger = None
+     if cfg.logger.backend:
+         exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
+         logger = get_logger(
+             cfg.logger.backend,
+             logger_name="ppo",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+         logger_video = cfg.logger.video
+     else:
+         logger_video = False
+
+     # Create test environment
+     test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
+     if logger_video:
+         test_env = test_env.append_transform(
+             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+         )
+     test_env.eval()
+
+     def update(batch, num_network_updates):
+         optim.zero_grad(set_to_none=True)
+         # Linearly decrease the learning rate and clip epsilon
+         alpha = torch.ones((), device=device)
+         if cfg_optim_anneal_lr:
+             alpha = 1 - (num_network_updates / total_network_updates)
+             for group in optim.param_groups:
+                 group["lr"] = cfg_optim_lr * alpha
+         if cfg_loss_anneal_clip_eps:
+             loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+         num_network_updates = num_network_updates + 1
+
+         # Forward pass PPO loss
+         loss = loss_module(batch)
+         critic_loss = loss["loss_critic"]
+         actor_loss = loss["loss_objective"] + loss["loss_entropy"]
+         total_loss = critic_loss + actor_loss
+
+         # Backward pass
+         total_loss.backward()
+
+         # Update the networks
+         optim.step()
+         return loss.detach().set("alpha", alpha), num_network_updates
+
+     if cfg.compile.compile:
+         update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+         adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
+
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+         adv_module = CudaGraphModule(adv_module)
+
+     # Main loop
+     collected_frames = 0
+     num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+     # extract cfg variables
+     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
+     cfg_optim_anneal_lr = cfg.optim.anneal_lr
+     cfg_optim_lr = torch.tensor(cfg.optim.lr, device=device)
+     cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
+     cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
+     cfg_logger_test_interval = cfg.logger.test_interval
+     cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+     losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
+
+     collector_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+
+         with timeit("collecting"):
+             data = next(collector_iter)
+
+         metrics_to_log = {}
+         frames_in_batch = data.numel()
+         collected_frames += frames_in_batch
+         pbar.update(frames_in_batch)
+
+         # Get training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+         if len(episode_rewards) > 0:
+             episode_length = data["next", "step_count"][data["next", "done"]]
+             metrics_to_log.update(
+                 {
+                     "train/reward": episode_rewards.mean().item(),
+                     "train/episode_length": episode_length.sum().item()
+                     / len(episode_length),
+                 }
+             )
+
+         with timeit("training"):
+             for j in range(cfg_loss_ppo_epochs):
+
+                 # Compute GAE
+                 with torch.no_grad(), timeit("adv"):
+                     torch.compiler.cudagraph_mark_step_begin()
+                     data = adv_module(data)
+                     if compile_mode:
+                         data = data.clone()
+
+                 with timeit("rb - extend"):
+                     # Update the data buffer
+                     data_reshape = data.reshape(-1)
+                     data_buffer.extend(data_reshape)
+
+                 for k, batch in enumerate(data_buffer):
+                     with timeit("update"):
+                         torch.compiler.cudagraph_mark_step_begin()
+                         loss, num_network_updates = update(
+                             batch, num_network_updates=num_network_updates
+                         )
+                         loss = loss.clone()
+                         num_network_updates = num_network_updates.clone()
+                     losses[j, k] = loss.select(
+                         "loss_critic", "loss_entropy", "loss_objective"
+                     )
+
+         # Get training losses and times
+         losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
+         for key, value in losses_mean.items():
+             metrics_to_log.update({f"train/{key}": value.item()})
+         metrics_to_log.update(
+             {
+                 "train/lr": loss["alpha"] * cfg_optim_lr,
+                 "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon
+                 if cfg_loss_anneal_clip_eps
+                 else cfg_loss_clip_epsilon,
+             }
+         )
+
+         # Get test rewards
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
+             if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
+                 i * frames_in_batch
+             ) // cfg_logger_test_interval:
+                 actor.eval()
+                 test_rewards = eval_model(
+                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
+                 )
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_rewards.mean(),
+                     }
+                 )
+                 actor.train()
+
+         if logger:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             for key, value in metrics_to_log.items():
+                 logger.log_scalar(key, value, collected_frames)
+
+         collector.update_policy_weights_()
+
+     collector.shutdown()
+     if not test_env.is_closed:
+         test_env.close()
+
+
+ if __name__ == "__main__":
+     main()
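Both scripts gate evaluation on whether the cumulative frame count has just crossed a multiple of the test interval: ((i - 1) * frames_in_batch) // test_interval < (i * frames_in_batch) // test_interval (the Atari script multiplies both sides by frame_skip). A small self-contained sketch with assumed numbers (4096 frames per batch and a 10,000-frame test interval; neither value comes from the configs, which are not shown in this diff):

    frames_in_batch = 4096  # assumed; set by cfg.collector.frames_per_batch in the scripts
    test_interval = 10_000  # assumed; set by cfg.logger.test_interval

    def should_eval(i):
        # True whenever batch i pushes the cumulative frame count past a multiple of test_interval.
        return ((i - 1) * frames_in_batch) // test_interval < (i * frames_in_batch) // test_interval

    print([i for i in range(11) if should_eval(i)])  # [0, 3, 5, 8, 10]

At i = 0 the left-hand side is negative, so an evaluation always runs after the first collected batch.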