torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/dqn/dqn_atari.py
@@ -0,0 +1,272 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ DQN: Reproducing experimental results from Mnih et al. 2015 for the
+ Deep Q-Learning Algorithm on Atari Environments.
+ """
+ from __future__ import annotations
+
+ import functools
+ import warnings
+
+ import hydra
+ import torch.nn
+ import torch.optim
+ import tqdm
+ from tensordict.nn import CudaGraphModule, TensorDictSequential
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+ from torchrl.envs import ExplorationType, set_exploration_type
+ from torchrl.modules import EGreedyModule
+ from torchrl.objectives import DQNLoss, HardUpdate
+ from torchrl.record import VideoRecorder
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils_atari import eval_model, make_dqn_model, make_env
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+
+     device = torch.device(cfg.device) if cfg.device else get_available_device()
+
+     # Correct for frame_skip
+     frame_skip = 4
+     total_frames = cfg.collector.total_frames // frame_skip
+     frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+     init_random_frames = cfg.collector.init_random_frames // frame_skip
+     test_interval = cfg.logger.test_interval // frame_skip
+
+     # Make the components
+     model = make_dqn_model(
+         cfg.env.env_name,
+         gym_backend=cfg.env.backend,
+         frame_skip=frame_skip,
+         device=device,
+     )
+     greedy_module = EGreedyModule(
+         annealing_num_steps=cfg.collector.annealing_frames,
+         eps_init=cfg.collector.eps_start,
+         eps_end=cfg.collector.eps_end,
+         spec=model.spec,
+         device=device,
+     )
+     model_explore = TensorDictSequential(
+         model,
+         greedy_module,
+     )
+
+     # Create the replay buffer
+     if cfg.buffer.scratch_dir in ("", None):
+         storage_cls = LazyMemmapStorage
+     else:
+         storage_cls = functools.partial(
+             LazyMemmapStorage, scratch_dir=cfg.buffer.scratch_dir
+         )
+
+     def transform(td):
+         return td.to(device)
+
+     replay_buffer = TensorDictReplayBuffer(
+         pin_memory=False,
+         storage=storage_cls(
+             max_size=cfg.buffer.buffer_size,
+         ),
+         batch_size=cfg.buffer.batch_size,
+     )
+     if transform is not None:
+         replay_buffer.append_transform(transform)
+
+     # Create the loss module
+     loss_module = DQNLoss(
+         value_network=model,
+         loss_function="l2",
+         delay_value=True,
+     )
+     loss_module.set_keys(done="end-of-life", terminated="end-of-life")
+     loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
+     target_net_updater = HardUpdate(
+         loss_module, value_network_update_interval=cfg.loss.hard_update_freq
+     )
+
+     # Create the optimizer
+     optimizer = torch.optim.Adam(loss_module.parameters(), lr=cfg.optim.lr)
+
+     # Create the logger
+     logger = None
+     if cfg.logger.backend:
+         exp_name = generate_exp_name("DQN", f"Atari_mnih15_{cfg.env.env_name}")
+         logger = get_logger(
+             cfg.logger.backend,
+             logger_name="dqn",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Create the test environment
+     test_env = make_env(
+         cfg.env.env_name,
+         frame_skip,
+         device,
+         gym_backend=cfg.env.backend,
+         is_test=True,
+     )
+     if cfg.logger.video:
+         test_env.insert_transform(
+             0,
+             VideoRecorder(
+                 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
+             ),
+         )
+     test_env.eval()
+
+     def update(sampled_tensordict):
+         loss_td = loss_module(sampled_tensordict)
+         q_loss = loss_td["loss"]
+         optimizer.zero_grad()
+         q_loss.backward()
+         torch.nn.utils.clip_grad_norm_(
+             list(loss_module.parameters()), max_norm=max_grad
+         )
+         optimizer.step()
+         target_net_updater.step()
+         return q_loss.detach()
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Create the collector
+     collector = SyncDataCollector(
+         create_env_fn=make_env(
+             cfg.env.env_name, frame_skip, device, gym_backend=cfg.env.backend
+         ),
+         policy=model_explore,
+         frames_per_batch=frames_per_batch,
+         total_frames=total_frames,
+         device=device,
+         storing_device=device,
+         max_frames_per_traj=-1,
+         init_random_frames=init_random_frames,
+         compile_policy={"mode": compile_mode, "fullgraph": True}
+         if compile_mode is not None
+         else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+
+     # Main loop
+     collected_frames = 0
+     num_updates = cfg.loss.num_updates
+     max_grad = cfg.optim.max_grad_norm
+     num_test_episodes = cfg.logger.num_test_episodes
+     q_losses = torch.zeros(num_updates, device=device)
+     pbar = tqdm.tqdm(total=total_frames)
+
+     c_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+         with timeit("collecting"):
+             data = next(c_iter)
+         metrics_to_log = {}
+         pbar.update(data.numel())
+         data = data.reshape(-1)
+         current_frames = data.numel() * frame_skip
+         collected_frames += current_frames
+         greedy_module.step(current_frames)
+         with timeit("rb - extend"):
+             replay_buffer.extend(data)
+
+         # Get and log training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+         if len(episode_rewards) > 0:
+             episode_reward_mean = episode_rewards.mean().item()
+             episode_length = data["next", "step_count"][data["next", "done"]]
+             episode_length_mean = episode_length.sum().item() / len(episode_length)
+             metrics_to_log.update(
+                 {
+                     "train/episode_reward": episode_reward_mean,
+                     "train/episode_length": episode_length_mean,
+                 }
+             )
+
+         if collected_frames < init_random_frames:
+             if logger:
+                 for key, value in metrics_to_log.items():
+                     logger.log_scalar(key, value, step=collected_frames)
+             continue
+
+         # optimization steps
+         for j in range(num_updates):
+             with timeit("rb - sample"):
+                 sampled_tensordict = replay_buffer.sample()
+             with timeit("update"):
+                 q_loss = update(sampled_tensordict)
+             q_losses[j].copy_(q_loss)
+
+         # Get and log q-values, loss, epsilon, sampling time and training time
+         metrics_to_log.update(
+             {
+                 "train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
+                 "train/q_loss": q_losses.mean(),
+                 "train/epsilon": greedy_module.eps,
+             }
+         )
+
+         # Get and log evaluation rewards and eval time
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
+             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
+             cur_test_frame = (i * frames_per_batch) // test_interval
+             final = current_frames >= collector.total_frames
+             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
+                 model.eval()
+                 test_rewards = eval_model(
+                     model, test_env, num_episodes=num_test_episodes
+                 )
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_rewards,
+                     }
+                 )
+                 model.train()
+
+         # Log all the information
+         if logger:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             for key, value in metrics_to_log.items():
+                 logger.log_scalar(key, value, step=collected_frames)
+
+         # update weights of the inference policy
+         collector.update_policy_weights_()
+
+     collector.shutdown()
+     if not test_env.is_closed:
+         test_env.close()
+
+
+ if __name__ == "__main__":
+     main()
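
The frame_skip bookkeeping at the top of dqn_atari.py is easy to misread: the collector counts post-skip steps, while logging reports raw environment frames. A minimal sketch of that arithmetic, using illustrative values rather than the defaults shipped in config_atari.yaml:

frame_skip = 4
total_frames = 40_000_000 // frame_skip       # 10_000_000 collector steps
frames_per_batch = 4_000 // frame_skip        # 1_000 collector steps per batch
# Each collected step spans frame_skip simulator frames, so the logged counter
# advances by data.numel() * frame_skip, i.e. 4_000 environment frames per batch.
current_frames = frames_per_batch * frame_skip
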
sota-implementations/dqn/dqn_cartpole.py
@@ -0,0 +1,236 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import warnings
+
+ import hydra
+ import torch.nn
+ import torch.optim
+ import tqdm
+ from tensordict.nn import CudaGraphModule, TensorDictSequential
+ from torchrl._utils import get_available_device, timeit
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+ from torchrl.envs import ExplorationType, set_exploration_type
+ from torchrl.modules import EGreedyModule
+ from torchrl.objectives import DQNLoss, HardUpdate
+ from torchrl.record import VideoRecorder
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils_cartpole import eval_model, make_dqn_model, make_env
+
+ torch.set_float32_matmul_precision("high")
+
+
+ @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
+ def main(cfg: DictConfig):  # noqa: F821
+
+     device = torch.device(cfg.device) if cfg.device else get_available_device()
+
+     # Make the components
+     model = make_dqn_model(cfg.env.env_name, device=device)
+
+     greedy_module = EGreedyModule(
+         annealing_num_steps=cfg.collector.annealing_frames,
+         eps_init=cfg.collector.eps_start,
+         eps_end=cfg.collector.eps_end,
+         spec=model.spec,
+         device=device,
+     )
+     model_explore = TensorDictSequential(
+         model,
+         greedy_module,
+     )
+
+     # Create the replay buffer
+     replay_buffer = TensorDictReplayBuffer(
+         pin_memory=False,
+         storage=LazyTensorStorage(max_size=cfg.buffer.buffer_size, device=device),
+         batch_size=cfg.buffer.batch_size,
+     )
+
+     # Create the loss module
+     loss_module = DQNLoss(
+         value_network=model,
+         loss_function="l2",
+         delay_value=True,
+     )
+     loss_module.make_value_estimator(gamma=cfg.loss.gamma, device=device)
+     loss_module = loss_module.to(device)
+     target_net_updater = HardUpdate(
+         loss_module, value_network_update_interval=cfg.loss.hard_update_freq
+     )
+
+     # Create the optimizer
+     optimizer = torch.optim.Adam(loss_module.parameters(), lr=cfg.optim.lr)
+
+     # Create the logger
+     logger = None
+     if cfg.logger.backend:
+         exp_name = generate_exp_name("DQN", f"CartPole_{cfg.env.env_name}")
+         logger = get_logger(
+             cfg.logger.backend,
+             logger_name="dqn",
+             experiment_name=exp_name,
+             wandb_kwargs={
+                 "config": dict(cfg),
+                 "project": cfg.logger.project_name,
+                 "group": cfg.logger.group_name,
+             },
+         )
+
+     # Create the test environment
+     test_env = make_env(cfg.env.env_name, "cpu", from_pixels=cfg.logger.video)
+     if cfg.logger.video:
+         test_env.insert_transform(
+             0,
+             VideoRecorder(
+                 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"]
+             ),
+         )
+
+     def update(sampled_tensordict):
+         loss_td = loss_module(sampled_tensordict)
+         q_loss = loss_td["loss"]
+         optimizer.zero_grad()
+         q_loss.backward()
+         optimizer.step()
+         target_net_updater.step()
+         return q_loss.detach()
+
+     compile_mode = None
+     if cfg.compile.compile:
+         compile_mode = cfg.compile.compile_mode
+         if compile_mode in ("", None):
+             if cfg.compile.cudagraphs:
+                 compile_mode = "default"
+             else:
+                 compile_mode = "reduce-overhead"
+         update = torch.compile(update, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)
+
+     # Create the collector
+     collector = SyncDataCollector(
+         create_env_fn=make_env(cfg.env.env_name, "cpu"),
+         policy=model_explore,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=cfg.collector.total_frames,
+         device="cpu",
+         storing_device="cpu",
+         max_frames_per_traj=-1,
+         init_random_frames=cfg.collector.init_random_frames,
+         compile_policy={"mode": compile_mode, "fullgraph": True}
+         if compile_mode is not None
+         else False,
+         cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
+     )
+
+     # Main loop
+     collected_frames = 0
+     num_updates = cfg.loss.num_updates
+     batch_size = cfg.buffer.batch_size
+     test_interval = cfg.logger.test_interval
+     num_test_episodes = cfg.logger.num_test_episodes
+     frames_per_batch = cfg.collector.frames_per_batch
+     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+     init_random_frames = cfg.collector.init_random_frames
+     q_losses = torch.zeros(num_updates, device=device)
+
+     c_iter = iter(collector)
+     total_iter = len(collector)
+     for i in range(total_iter):
+         timeit.printevery(1000, total_iter, erase=True)
+         with timeit("collecting"):
+             data = next(c_iter)
+
+         metrics_to_log = {}
+         pbar.update(data.numel())
+         data = data.reshape(-1)
+         current_frames = data.numel()
+
+         with timeit("rb - extend"):
+             replay_buffer.extend(data)
+         collected_frames += current_frames
+         greedy_module.step(current_frames)
+
+         # Get and log training rewards and episode lengths
+         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+         if len(episode_rewards) > 0:
+             episode_reward_mean = episode_rewards.mean().item()
+             episode_length = data["next", "step_count"][data["next", "done"]]
+             episode_length_mean = episode_length.sum().item() / len(episode_length)
+             metrics_to_log.update(
+                 {
+                     "train/episode_reward": episode_reward_mean,
+                     "train/episode_length": episode_length_mean,
+                 }
+             )
+
+         if collected_frames < init_random_frames:
+             if logger:
+                 for key, value in metrics_to_log.items():
+                     logger.log_scalar(key, value, step=collected_frames)
+             continue
+
+         # optimization steps
+         for j in range(num_updates):
+             with timeit("rb - sample"):
+                 sampled_tensordict = replay_buffer.sample(batch_size)
+                 sampled_tensordict = sampled_tensordict.to(device)
+             with timeit("update"):
+                 q_loss = update(sampled_tensordict)
+             q_losses[j].copy_(q_loss)
+
+         # Get and log q-values, loss, epsilon, sampling time and training time
+         metrics_to_log.update(
+             {
+                 "train/q_values": (data["action_value"] * data["action"]).sum().item()
+                 / frames_per_batch,
+                 "train/q_loss": q_losses.mean().item(),
+                 "train/epsilon": greedy_module.eps,
+             }
+         )
+
+         # Get and log evaluation rewards and eval time
+         with torch.no_grad(), set_exploration_type(
+             ExplorationType.DETERMINISTIC
+         ), timeit("eval"):
+             prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
+             cur_test_frame = (i * frames_per_batch) // test_interval
+             final = current_frames >= collector.total_frames
+             if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
+                 model.eval()
+                 test_rewards = eval_model(model, test_env, num_test_episodes)
+                 model.train()
+                 metrics_to_log.update(
+                     {
+                         "eval/reward": test_rewards,
+                     }
+                 )
+
+         # Log all the information
+         if logger:
+             metrics_to_log.update(timeit.todict(prefix="time"))
+             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+             for key, value in metrics_to_log.items():
+                 logger.log_scalar(key, value, step=collected_frames)
+
+         # update weights of the inference policy
+         collector.update_policy_weights_()
+
+     collector.shutdown()
+     if not test_env.is_closed:
+         test_env.close()
+
+
+ if __name__ == "__main__":
+     main()
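
The main difference between the two scripts is where the replay buffer lives: dqn_atari.py keeps a large memory-mapped storage off-device and casts sampled batches to the training device through an appended transform, while dqn_cartpole.py allocates its storage directly on the device. A minimal sketch of the two setups, with illustrative capacities and batch sizes:

import torch
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Atari-style: memory-mapped storage kept on disk/CPU, samples moved to device.
atari_style_rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(max_size=1_000_000),
    batch_size=256,
)
atari_style_rb.append_transform(lambda td: td.to(device))

# CartPole-style: storage allocated directly on the training device.
cartpole_style_rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=10_000, device=device),
    batch_size=256,
)
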
sota-implementations/dqn/utils_atari.py
@@ -0,0 +1,132 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch.nn
+ import torch.optim
+ from torchrl.data import Composite
+ from torchrl.envs import (
+     CatFrames,
+     DoubleToFloat,
+     EndOfLifeTransform,
+     GrayScale,
+     GymEnv,
+     NoopResetEnv,
+     Resize,
+     RewardSum,
+     set_gym_backend,
+     SignTransform,
+     StepCounter,
+     ToTensorImage,
+     TransformedEnv,
+     VecNorm,
+ )
+
+ from torchrl.modules import ConvNet, MLP, QValueActor
+ from torchrl.record import VideoRecorder
+
+
+ # ====================================================================
+ # Environment utils
+ # --------------------------------------------------------------------
+
+
+ def make_env(env_name, frame_skip, device, gym_backend, is_test=False):
+     with set_gym_backend(gym_backend):
+         env = GymEnv(
+             env_name,
+             frame_skip=frame_skip,
+             from_pixels=True,
+             pixels_only=False,
+             device=device,
+             categorical_action_encoding=True,
+         )
+     env = TransformedEnv(env)
+     env.append_transform(NoopResetEnv(noops=30, random=True))
+     if not is_test:
+         env.append_transform(EndOfLifeTransform())
+         env.append_transform(SignTransform(in_keys=["reward"]))
+     env.append_transform(ToTensorImage())
+     env.append_transform(GrayScale())
+     env.append_transform(Resize(84, 84))
+     env.append_transform(CatFrames(N=4, dim=-3))
+     env.append_transform(RewardSum())
+     env.append_transform(StepCounter(max_steps=4500))
+     env.append_transform(DoubleToFloat())
+     env.append_transform(VecNorm(in_keys=["pixels"]))
+     return env
+
+
+ # ====================================================================
+ # Model utils
+ # --------------------------------------------------------------------
+
+
+ def make_dqn_modules_pixels(proof_environment, device):
+
+     # Define input shape
+     input_shape = proof_environment.observation_spec["pixels"].shape
+     env_specs = proof_environment.specs
+     num_actions = env_specs["input_spec", "full_action_spec", "action"].space.n
+     action_spec = env_specs["input_spec", "full_action_spec", "action"]
+
+     # Define Q-Value Module
+     cnn = ConvNet(
+         activation_class=torch.nn.ReLU,
+         num_cells=[32, 64, 64],
+         kernel_sizes=[8, 4, 3],
+         strides=[4, 2, 1],
+         device=device,
+     )
+     cnn_output = cnn(torch.ones(input_shape, device=device))
+     mlp = MLP(
+         in_features=cnn_output.shape[-1],
+         activation_class=torch.nn.ReLU,
+         out_features=num_actions,
+         num_cells=[512],
+         device=device,
+     )
+     qvalue_module = QValueActor(
+         module=torch.nn.Sequential(cnn, mlp),
+         spec=Composite(action=action_spec).to(device),
+         in_keys=["pixels"],
+     )
+     return qvalue_module
+
+
+ def make_dqn_model(env_name, gym_backend, frame_skip, device):
+     proof_environment = make_env(
+         env_name, frame_skip, gym_backend=gym_backend, device=device
+     )
+     qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
+     del proof_environment
+     return qvalue_module
+
+
+ # ====================================================================
+ # Evaluation utils
+ # --------------------------------------------------------------------
+
+
+ def eval_model(actor, test_env, num_episodes=3):
+     test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
+     for i in range(num_episodes):
+         td_test = test_env.rollout(
+             policy=actor,
+             auto_reset=True,
+             auto_cast_to_device=True,
+             break_when_any_done=True,
+             max_steps=10_000_000,
+         )
+         test_env.apply(dump_video)
+         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+         test_rewards[i] = reward.sum()
+         del td_test
+     return test_rewards.mean()
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
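
As a usage sketch, the helpers above compose as follows. This assumes a gymnasium installation with the ALE Atari environments available; the environment id and backend are illustrative, not taken from the shipped config files:

import torch
from torchrl.envs import ExplorationType, set_exploration_type

from utils_atari import make_dqn_model, make_env

device = "cuda" if torch.cuda.is_available() else "cpu"
model = make_dqn_model(
    "PongNoFrameskip-v4", gym_backend="gymnasium", frame_skip=4, device=device
)
test_env = make_env(
    "PongNoFrameskip-v4", 4, device, gym_backend="gymnasium", is_test=True
)

# Greedy rollout with the freshly initialized Q-network.
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
    rollout = test_env.rollout(max_steps=200, policy=model)
print(rollout["next", "episode_reward"][-1])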