torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,11 @@
1
+ vllm==0.11.0
2
+ peft
3
+ bitsandbytes
4
+ datasets
5
+ wandb
6
+ hydra-core
7
+ ray
8
+ tqdm
9
+ tensordict
10
+ accelerate
11
+ xformers
@@ -0,0 +1,16 @@
1
+ vllm==0.11.0
2
+ torch
3
+ transformers
4
+ peft
5
+ bitsandbytes
6
+ datasets
7
+ wandb
8
+ hydra-core
9
+ ray
10
+ tqdm
11
+ tensordict
12
+ accelerate
13
+ xformers
14
+ nltk
15
+ langdetect
16
+ immutabledict
@@ -0,0 +1,33 @@
1
+ ## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results
2
+
3
+ This repository contains scripts that enable training agents using the IMPALA Algorithm on Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al., 2018.
4
+
5
+ ## Examples Structure
6
+
7
+ Please note that we provide two examples: one for single-node training and one for distributed training. Both examples rely on the same utils file, but are otherwise independent. Each example contains the following files:
8
+
9
+ 1. **Main Script:** The definition of the algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py or impala_multi_node_ray.py).
10
+
11
+ 2. **Utils File:** A utility file provides various helper functions, mainly for creating the environment and the models (e.g. utils.py).
12
+
13
+ 3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. For the multi-node case, it also includes the Ray cluster configuration. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).
14
+
15
+
16
+ ## Running the Examples
17
+
18
+ You can execute the single node IMPALA algorithm on Atari environments by running the following command:
19
+
20
+ ```bash
21
+ python impala_single_node.py
22
+ ```
23
+
24
+ You can execute the multi-node IMPALA algorithm on Atari environments by running the following command:
25
+
26
+ ```bash
27
+ python impala_multi_node_ray.py
28
+ ```
29
+ or
30
+
31
+ ```bash
32
+ python impala_multi_node_submitit.py
33
+ ```
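Both entry points are Hydra scripts, so any hyperparameter in the configuration files can be overridden from the command line with Hydra's `key=value` syntax, e.g. `python impala_single_node.py optim.lr=0.0005 collector.num_workers=8` (the keys shown here are illustrative; check the config files for the exact names).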
@@ -0,0 +1,292 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ This script reproduces the IMPALA Algorithm
8
+ results from Espeholt et al., 2018 on Atari environments.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import hydra
13
+ from torchrl._utils import logger as torchrl_logger
14
+
15
+
16
+ @hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1")
17
+ def main(cfg: DictConfig): # noqa: F821
18
+
19
+ import time
20
+
21
+ import torch.optim
22
+ import tqdm
23
+
24
+ from tensordict import TensorDict
25
+ from torchrl.collectors import SyncDataCollector
26
+ from torchrl.collectors.distributed import RayCollector
27
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
28
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
29
+ from torchrl.envs import ExplorationType, set_exploration_type
30
+ from torchrl.objectives import A2CLoss
31
+ from torchrl.objectives.value import VTrace
32
+ from torchrl.record.loggers import generate_exp_name, get_logger
33
+ from utils import eval_model, make_env, make_ppo_models
34
+
35
+ device = cfg.local_device
36
+ if not device:
37
+ device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
38
+ else:
39
+ device = torch.device(device)
40
+
41
+ # Correct for frame_skip
42
+ frame_skip = 4
43
+ total_frames = cfg.collector.total_frames // frame_skip
44
+ frames_per_batch = cfg.collector.frames_per_batch // frame_skip
45
+ test_interval = cfg.logger.test_interval // frame_skip
46
+
47
+ # Extract other config parameters
48
+ batch_size = cfg.loss.batch_size # Number of rollouts per batch
49
+ num_workers = (
50
+ cfg.collector.num_workers
51
+ ) # Number of parallel workers collecting rollouts
52
+ lr = cfg.optim.lr
53
+ anneal_lr = cfg.optim.anneal_lr
54
+ sgd_updates = cfg.loss.sgd_updates
55
+ max_grad_norm = cfg.optim.max_grad_norm
56
+ num_test_episodes = cfg.logger.num_test_episodes
57
+ total_network_updates = (
58
+ total_frames // (frames_per_batch * batch_size)
59
+ ) * cfg.loss.sgd_updates
60
+
61
+ # Create models (check utils.py)
62
+ actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
63
+ actor, critic = actor.to(device), critic.to(device)
64
+
65
+ # Create collector
66
+ ray_init_config = {
67
+ "address": cfg.ray_init_config.address,
68
+ "num_cpus": cfg.ray_init_config.num_cpus,
69
+ "num_gpus": cfg.ray_init_config.num_gpus,
70
+ "resources": cfg.ray_init_config.resources,
71
+ "object_store_memory": cfg.ray_init_config.object_store_memory,
72
+ "local_mode": cfg.ray_init_config.local_mode,
73
+ "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error,
74
+ "include_dashboard": cfg.ray_init_config.include_dashboard,
75
+ "dashboard_host": cfg.ray_init_config.dashboard_host,
76
+ "dashboard_port": cfg.ray_init_config.dashboard_port,
77
+ "job_config": cfg.ray_init_config.job_config,
78
+ "configure_logging": cfg.ray_init_config.configure_logging,
79
+ "logging_level": cfg.ray_init_config.logging_level,
80
+ "logging_format": cfg.ray_init_config.logging_format,
81
+ "log_to_driver": cfg.ray_init_config.log_to_driver,
82
+ "namespace": cfg.ray_init_config.namespace,
83
+ "runtime_env": cfg.ray_init_config.runtime_env,
84
+ "storage": cfg.ray_init_config.storage,
85
+ }
86
+ remote_config = {
87
+ "num_cpus": cfg.remote_worker_resources.num_cpus,
88
+ "num_gpus": cfg.remote_worker_resources.num_gpus
89
+ if torch.cuda.device_count()
90
+ else 0,
91
+ "memory": cfg.remote_worker_resources.memory,
92
+ }
93
+ collector = RayCollector(
94
+ create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
95
+ * num_workers,
96
+ policy=actor,
97
+ collector_class=SyncDataCollector,
98
+ frames_per_batch=frames_per_batch,
99
+ total_frames=total_frames,
100
+ max_frames_per_traj=-1,
101
+ ray_init_config=ray_init_config,
102
+ remote_configs=remote_config,
103
+ sync=False,
104
+ update_after_each_batch=True,
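+ # sync=False runs collection asynchronously: each worker's batch is
+ # yielded as soon as it is ready, and update_after_each_batch pushes
+ # fresh policy weights to that worker after every batch it delivers.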
105
+ )
106
+
107
+ # Create data buffer
108
+ sampler = SamplerWithoutReplacement()
109
+ data_buffer = TensorDictReplayBuffer(
110
+ storage=LazyMemmapStorage(frames_per_batch * batch_size),
111
+ sampler=sampler,
112
+ batch_size=frames_per_batch * batch_size,
113
+ )
114
+
115
+ # Create loss and adv modules
116
+ adv_module = VTrace(
117
+ gamma=cfg.loss.gamma,
118
+ value_network=critic,
119
+ actor_network=actor,
120
+ average_adv=False,
121
+ )
122
+ loss_module = A2CLoss(
123
+ actor_network=actor,
124
+ critic_network=critic,
125
+ loss_critic_type=cfg.loss.loss_critic_type,
126
+ entropy_coeff=cfg.loss.entropy_coeff,
127
+ critic_coeff=cfg.loss.critic_coeff,
128
+ )
129
+ loss_module.set_keys(done="eol", terminated="eol")
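+ # The loss reads "eol" (end-of-life) as its done/terminated signal, so a
+ # lost life stops value bootstrapping; the key is assumed to be written
+ # by the Atari end-of-life transform set up in utils.py.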
130
+
131
+ # Create optimizer
132
+ optim = torch.optim.RMSprop(
133
+ loss_module.parameters(),
134
+ lr=cfg.optim.lr,
135
+ weight_decay=cfg.optim.weight_decay,
136
+ eps=cfg.optim.eps,
137
+ alpha=cfg.optim.alpha,
138
+ )
139
+
140
+ # Create logger
141
+ logger = None
142
+ if cfg.logger.backend:
143
+ exp_name = generate_exp_name(
144
+ "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
145
+ )
146
+ logger = get_logger(
147
+ cfg.logger.backend,
148
+ logger_name="impala",
149
+ experiment_name=exp_name,
150
+ wandb_kwargs={
151
+ "config": dict(cfg),
152
+ "project": cfg.logger.project_name,
153
+ "group": cfg.logger.group_name,
154
+ },
155
+ )
156
+
157
+ # Create test environment
158
+ test_env = make_env(
159
+ cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
160
+ )
161
+ test_env.eval()
162
+
163
+ # Main loop
164
+ collected_frames = 0
165
+ num_network_updates = 0
166
+ pbar = tqdm.tqdm(total=total_frames)
167
+ accumulator = []
168
+ start_time = sampling_start = time.time()
169
+ for i, data in enumerate(collector):
170
+
171
+ metrics_to_log = {}
172
+ sampling_time = time.time() - sampling_start
173
+ frames_in_batch = data.numel()
174
+ collected_frames += frames_in_batch * frame_skip
175
+ pbar.update(data.numel())
176
+
177
+ # Get training rewards and episode lengths
178
+ episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
179
+ if len(episode_rewards) > 0:
180
+ episode_length = data["next", "step_count"][data["next", "terminated"]]
181
+ metrics_to_log.update(
182
+ {
183
+ "train/reward": episode_rewards.mean().item(),
184
+ "train/episode_length": episode_length.sum().item()
185
+ / len(episode_length),
186
+ }
187
+ )
188
+
189
+ if len(accumulator) < batch_size:
190
+ accumulator.append(data)
191
+ if logger:
192
+ for key, value in metrics_to_log.items():
193
+ logger.log_scalar(key, value, collected_frames)
194
+ continue
195
+
196
+ losses = TensorDict(batch_size=[sgd_updates])
197
+ training_start = time.time()
198
+ for j in range(sgd_updates):
199
+
200
+ # Create a single batch of trajectories
201
+ stacked_data = torch.stack(accumulator, dim=0).contiguous()
202
+ stacked_data = stacked_data.to(device, non_blocking=True)
203
+
204
+ # Compute advantage
205
+ with torch.no_grad():
206
+ stacked_data = adv_module(stacked_data)
207
+
208
+ # Add to replay buffer
209
+ for stacked_d in stacked_data:
210
+ stacked_data_reshape = stacked_d.reshape(-1)
211
+ data_buffer.extend(stacked_data_reshape)
212
+
213
+ for batch in data_buffer:
214
+
215
+ # Linearly anneal the learning rate
216
+ alpha = 1.0
217
+ if anneal_lr:
218
+ alpha = 1 - (num_network_updates / total_network_updates)
219
+ for group in optim.param_groups:
220
+ group["lr"] = lr * alpha
221
+ num_network_updates += 1
222
+
223
+ # Get a data batch
224
+ batch = batch.to(device, non_blocking=True)
225
+
226
+ # Forward pass loss
227
+ loss = loss_module(batch)
228
+ losses[j] = loss.select(
229
+ "loss_critic", "loss_entropy", "loss_objective"
230
+ ).detach()
231
+ loss_sum = (
232
+ loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
233
+ )
234
+
235
+ # Backward pass
236
+ loss_sum.backward()
237
+ torch.nn.utils.clip_grad_norm_(
238
+ list(loss_module.parameters()), max_norm=max_grad_norm
239
+ )
240
+
241
+ # Update the networks
242
+ optim.step()
243
+ optim.zero_grad()
244
+
245
+ # Get training losses and times
246
+ training_time = time.time() - training_start
247
+ losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
248
+ for key, value in losses.items():
249
+ metrics_to_log.update({f"train/{key}": value.item()})
250
+ metrics_to_log.update(
251
+ {
252
+ "train/lr": alpha * lr,
253
+ "train/sampling_time": sampling_time,
254
+ "train/training_time": training_time,
255
+ }
256
+ )
257
+
258
+ # Get test rewards
259
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
260
+ if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
261
+ i * frames_in_batch * frame_skip
262
+ ) // test_interval:
263
+ actor.eval()
264
+ eval_start = time.time()
265
+ test_reward = eval_model(
266
+ actor, test_env, num_episodes=num_test_episodes
267
+ )
268
+ eval_time = time.time() - eval_start
269
+ metrics_to_log.update(
270
+ {
271
+ "eval/reward": test_reward,
272
+ "eval/time": eval_time,
273
+ }
274
+ )
275
+ actor.train()
276
+
277
+ if logger:
278
+ for key, value in metrics_to_log.items():
279
+ logger.log_scalar(key, value, collected_frames)
280
+
281
+ collector.update_policy_weights_()
282
+ sampling_start = time.time()
283
+ accumulator = []
284
+
285
+ collector.shutdown()
286
+ end_time = time.time()
287
+ execution_time = end_time - start_time
288
+ torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
289
+
290
+
291
+ if __name__ == "__main__":
292
+ main()
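The evaluation condition in the loop above packs its logic into one inline expression. The sketch below restates it as a standalone predicate, using hypothetical values for `frames_in_batch` and `test_interval` rather than the actual config defaults: evaluation fires whenever the cumulative frame count crosses a multiple of `test_interval` between consecutive batches.

```python
def should_eval(
    i: int,
    frames_in_batch: int = 1024,   # hypothetical value of data.numel()
    frame_skip: int = 4,
    test_interval: int = 200_000,  # hypothetical, not the config default
) -> bool:
    """True when total collected frames cross a multiple of test_interval."""
    before = ((i - 1) * frames_in_batch * frame_skip) // test_interval
    after = (i * frames_in_batch * frame_skip) // test_interval
    return before < after

# With 4096 frames per batch, evaluation fires roughly every 49 batches:
print([i for i in range(1, 150) if should_eval(i)])  # [49, 98, 147]
```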
@@ -0,0 +1,284 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ This script reproduces the IMPALA Algorithm
8
+ results from Espeholt et al., 2018 on Atari environments.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import hydra
13
+ from torchrl._utils import logger as torchrl_logger
14
+
15
+
16
+ @hydra.main(
17
+ config_path="", config_name="config_multi_node_submitit", version_base="1.1"
18
+ )
19
+ def main(cfg: DictConfig): # noqa: F821
20
+
21
+ import time
22
+
23
+ import torch.optim
24
+ import tqdm
25
+
26
+ from tensordict import TensorDict
27
+ from torchrl.collectors import SyncDataCollector
28
+ from torchrl.collectors.distributed import DistributedDataCollector
29
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
30
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
31
+ from torchrl.envs import ExplorationType, set_exploration_type
32
+ from torchrl.objectives import A2CLoss
33
+ from torchrl.objectives.value import VTrace
34
+ from torchrl.record.loggers import generate_exp_name, get_logger
35
+ from utils import eval_model, make_env, make_ppo_models
36
+
37
+ device = cfg.local_device
38
+ if not device:
39
+ device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
40
+ else:
41
+ device = torch.device(device)
42
+
43
+ # Correct for frame_skip
44
+ frame_skip = 4
45
+ total_frames = cfg.collector.total_frames // frame_skip
46
+ frames_per_batch = cfg.collector.frames_per_batch // frame_skip
47
+ test_interval = cfg.logger.test_interval // frame_skip
48
+
49
+ # Extract other config parameters
50
+ batch_size = cfg.loss.batch_size # Number of rollouts per batch
51
+ num_workers = (
52
+ cfg.collector.num_workers
53
+ ) # Number of parallel workers collecting rollouts
54
+ lr = cfg.optim.lr
55
+ anneal_lr = cfg.optim.anneal_lr
56
+ sgd_updates = cfg.loss.sgd_updates
57
+ max_grad_norm = cfg.optim.max_grad_norm
58
+ num_test_episodes = cfg.logger.num_test_episodes
59
+ total_network_updates = (
60
+ total_frames // (frames_per_batch * batch_size)
61
+ ) * cfg.loss.sgd_updates
62
+
63
+ # Create models (check utils.py)
64
+ actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
65
+ actor, critic = actor.to(device), critic.to(device)
66
+
67
+ slurm_kwargs = {
68
+ "timeout_min": cfg.slurm_config.timeout_min,
69
+ "slurm_partition": cfg.slurm_config.slurm_partition,
70
+ "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task,
71
+ "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node,
72
+ }
73
+ # Create collector
74
+ device_str = "device" if num_workers <= 1 else "devices"
75
+ if cfg.collector.backend == "nccl":
76
+ collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
77
+ elif cfg.collector.backend == "gloo":
78
+ collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"}
79
+ else:
80
+ raise NotImplementedError(
81
+ f"device assignment not implemented for backend {cfg.collector.backend}"
82
+ )
83
+ collector = DistributedDataCollector(
84
+ create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
85
+ * num_workers,
86
+ policy=actor,
87
+ num_workers_per_collector=1,
88
+ frames_per_batch=frames_per_batch,
89
+ total_frames=total_frames,
90
+ collector_class=SyncDataCollector,
91
+ collector_kwargs=collector_kwargs,
92
+ slurm_kwargs=slurm_kwargs,
93
+ storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu",
94
+ launcher="submitit",
95
+ # update_after_each_batch=True,
96
+ backend=cfg.collector.backend,
97
+ )
98
+
99
+ # Create data buffer
100
+ sampler = SamplerWithoutReplacement()
101
+ data_buffer = TensorDictReplayBuffer(
102
+ storage=LazyMemmapStorage(frames_per_batch * batch_size),
103
+ sampler=sampler,
104
+ batch_size=frames_per_batch * batch_size,
105
+ )
106
+
107
+ # Create loss and adv modules
108
+ adv_module = VTrace(
109
+ gamma=cfg.loss.gamma,
110
+ value_network=critic,
111
+ actor_network=actor,
112
+ average_adv=False,
113
+ )
114
+ loss_module = A2CLoss(
115
+ actor_network=actor,
116
+ critic_network=critic,
117
+ loss_critic_type=cfg.loss.loss_critic_type,
118
+ entropy_coeff=cfg.loss.entropy_coeff,
119
+ critic_coeff=cfg.loss.critic_coeff,
120
+ )
121
+ loss_module.set_keys(done="eol", terminated="eol")
122
+
123
+ # Create optimizer
124
+ optim = torch.optim.RMSprop(
125
+ loss_module.parameters(),
126
+ lr=cfg.optim.lr,
127
+ weight_decay=cfg.optim.weight_decay,
128
+ eps=cfg.optim.eps,
129
+ alpha=cfg.optim.alpha,
130
+ )
131
+
132
+ # Create logger
133
+ logger = None
134
+ if cfg.logger.backend:
135
+ exp_name = generate_exp_name(
136
+ "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
137
+ )
138
+ logger = get_logger(
139
+ cfg.logger.backend,
140
+ logger_name="impala",
141
+ experiment_name=exp_name,
142
+ wandb_kwargs={
143
+ "config": dict(cfg),
144
+ "project": cfg.logger.project_name,
145
+ "group": cfg.logger.group_name,
146
+ },
147
+ )
148
+
149
+ # Create test environment
150
+ test_env = make_env(
151
+ cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
152
+ )
153
+ test_env.eval()
154
+
155
+ # Main loop
156
+ collected_frames = 0
157
+ num_network_updates = 0
158
+ pbar = tqdm.tqdm(total=total_frames)
159
+ accumulator = []
160
+ start_time = sampling_start = time.time()
161
+ for i, data in enumerate(collector):
162
+
163
+ metrics_to_log = {}
164
+ sampling_time = time.time() - sampling_start
165
+ frames_in_batch = data.numel()
166
+ collected_frames += frames_in_batch * frame_skip
167
+ pbar.update(data.numel())
168
+
169
+ # Get training rewards and episode lengths
170
+ episode_rewards = data["next", "episode_reward"][data["next", "done"]]
171
+ if len(episode_rewards) > 0:
172
+ episode_length = data["next", "step_count"][data["next", "done"]]
173
+ metrics_to_log.update(
174
+ {
175
+ "train/reward": episode_rewards.mean().item(),
176
+ "train/episode_length": episode_length.sum().item()
177
+ / len(episode_length),
178
+ }
179
+ )
180
+
181
+ if len(accumulator) < batch_size:
182
+ accumulator.append(data)
183
+ if logger:
184
+ for key, value in metrics_to_log.items():
185
+ logger.log_scalar(key, value, collected_frames)
186
+ continue
187
+
188
+ losses = TensorDict(batch_size=[sgd_updates])
189
+ training_start = time.time()
190
+ for j in range(sgd_updates):
191
+
192
+ # Create a single batch of trajectories
193
+ stacked_data = torch.stack(accumulator, dim=0).contiguous()
194
+ stacked_data = stacked_data.to(device, non_blocking=True)
195
+
196
+ # Compute advantage
197
+ with torch.no_grad():
198
+ stacked_data = adv_module(stacked_data)
199
+
200
+ # Add to replay buffer
201
+ for stacked_d in stacked_data:
202
+ stacked_data_reshape = stacked_d.reshape(-1)
203
+ data_buffer.extend(stacked_data_reshape)
204
+
205
+ for batch in data_buffer:
206
+
207
+ # Linearly anneal the learning rate
208
+ alpha = 1.0
209
+ if anneal_lr:
210
+ alpha = 1 - (num_network_updates / total_network_updates)
211
+ for group in optim.param_groups:
212
+ group["lr"] = lr * alpha
213
+ num_network_updates += 1
214
+
215
+ # Get a data batch
216
+ batch = batch.to(device)
217
+
218
+ # Forward pass loss
219
+ loss = loss_module(batch)
220
+ losses[j] = loss.select(
221
+ "loss_critic", "loss_entropy", "loss_objective"
222
+ ).detach()
223
+ loss_sum = (
224
+ loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
225
+ )
226
+
227
+ # Backward pass
228
+ loss_sum.backward()
229
+ torch.nn.utils.clip_grad_norm_(
230
+ list(loss_module.parameters()), max_norm=max_grad_norm
231
+ )
232
+
233
+ # Update the networks
234
+ optim.step()
235
+ optim.zero_grad()
236
+
237
+ # Get training losses and times
238
+ training_time = time.time() - training_start
239
+ losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
240
+ for key, value in losses.items():
241
+ metrics_to_log.update({f"train/{key}": value.item()})
242
+ metrics_to_log.update(
243
+ {
244
+ "train/lr": alpha * lr,
245
+ "train/sampling_time": sampling_time,
246
+ "train/training_time": training_time,
247
+ }
248
+ )
249
+
250
+ # Get test rewards
251
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
252
+ if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
253
+ i * frames_in_batch * frame_skip
254
+ ) // test_interval:
255
+ actor.eval()
256
+ eval_start = time.time()
257
+ test_reward = eval_model(
258
+ actor, test_env, num_episodes=num_test_episodes
259
+ )
260
+ eval_time = time.time() - eval_start
261
+ metrics_to_log.update(
262
+ {
263
+ "eval/reward": test_reward,
264
+ "eval/time": eval_time,
265
+ }
266
+ )
267
+ actor.train()
268
+
269
+ if logger:
270
+ for key, value in metrics_to_log.items():
271
+ logger.log_scalar(key, value, collected_frames)
272
+
273
+ collector.update_policy_weights_()
274
+ sampling_start = time.time()
275
+ accumulator = []
276
+
277
+ collector.shutdown()
278
+ end_time = time.time()
279
+ execution_time = end_time - start_time
280
+ torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
281
+
282
+
283
+ if __name__ == "__main__":
284
+ main()
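Both scripts share the same linear learning-rate schedule: the total number of optimizer steps is derived from the frame budget, and the rate decays linearly toward zero over those steps. A minimal sketch follows, with hypothetical stand-ins for the config values (the real defaults live in the yaml config files).

```python
# All values below are illustrative, not the shipped config defaults.
frame_skip = 4
total_frames = 200_000_000 // frame_skip
frames_per_batch = 80_000 // frame_skip
batch_size = 32   # rollouts accumulated per optimization phase (cfg.loss.batch_size)
sgd_updates = 1   # cfg.loss.sgd_updates
lr = 6e-4         # cfg.optim.lr

# One optimization phase consumes frames_per_batch * batch_size frames.
total_network_updates = (
    total_frames // (frames_per_batch * batch_size)
) * sgd_updates

# The lr multiplier alpha decays from 1 toward 0 across all updates.
for num_network_updates in (0, total_network_updates // 2, total_network_updates - 1):
    alpha = 1 - num_network_updates / total_network_updates
    print(f"update {num_network_updates}: lr = {lr * alpha:.2e}")
```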