torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import os
8
+ import pickle
9
+
10
+ import time
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+
15
+ import ray
16
+
17
+ import vmas
18
+ from matplotlib import pyplot as plt
19
+ from ray import tune
20
+
21
+ from ray.rllib.agents.ppo import PPOTrainer
22
+ from ray.rllib.algorithms.callbacks import DefaultCallbacks
23
+ from ray.tune import register_env
24
+ from torchrl._utils import logger as torchrl_logger
25
+ from torchrl.collectors import SyncDataCollector
26
+ from torchrl.envs.libs.vmas import VmasEnv
27
+ from vmas import Wrapper
28
+
29
+
def store_pickled_evaluation(name: str, evaluation: dict):
    """Persist benchmark results as ``<name>.pkl`` next to this script.

    Args:
        name (str): base file name, without the ``.pkl`` extension.
        evaluation (dict): benchmark results to serialize.
    """
    save_folder = f"{os.path.dirname(os.path.realpath(__file__))}"
    file = f"{save_folder}/{name}.pkl"

    # Use a context manager so the handle is closed deterministically
    # (the original left the file object open until garbage collection).
    with open(file, "wb") as f:
        pickle.dump(evaluation, f)
35
+
36
+
def load_pickled_evaluation(
    name: str,
):
    """Load previously stored benchmark results.

    Args:
        name (str): base file name, without the ``.pkl`` extension.

    Returns:
        dict | None: the unpickled results, or ``None`` if no such file exists.
    """
    save_folder = f"{os.path.dirname(os.path.realpath(__file__))}"
    file = Path(f"{save_folder}/{name}.pkl")

    if file.is_file():
        # Context manager closes the handle promptly (the original relied on
        # garbage collection to close the open file object).
        with file.open("rb") as f:
            return pickle.load(f)
    return None
46
+
47
+
def run_vmas_torchrl(
    scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0
):
    """Collect ``n_envs * n_steps`` frames from a VMAS scenario via TorchRL.

    Builds a :class:`VmasEnv` with discrete actions and a ``SyncDataCollector``
    sized so that exactly one batch covers the full frame budget, then times
    the collection loop.

    Returns:
        float: wall-clock seconds spent collecting.
    """
    vmas_env = VmasEnv(
        scenario_name,
        device=device,
        num_envs=n_envs,
        continuous_actions=False,
        seed=seed,
    )

    data_collector = SyncDataCollector(
        vmas_env,
        policy=None,
        device=device,
        frames_per_batch=n_envs * n_steps,
        total_frames=n_envs * n_steps,
    )

    start = time.time()
    # Drain the collector; with total_frames == frames_per_batch this is a
    # single batch, but iterating keeps the timing honest either way.
    for _batch in data_collector:
        pass
    elapsed = time.time() - start

    data_collector.shutdown()
    return elapsed
75
+
76
+
def run_vmas_rllib(
    scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0
):
    """Run one RLlib PPO iteration on a VMAS scenario and return the sampling time.

    The sampling time is derived from RLlib's reported timers (iteration time
    minus learning time, converted from milliseconds to seconds).

    Returns:
        float: sampling time in seconds for one training iteration.
    """

    class TimerCallback(DefaultCallbacks):
        # Filled in by on_train_result; read back after tune.run completes.
        result_time = None

        def on_train_result(
            self,
            *,
            algorithm,
            result: dict,
            **kwargs,
        ) -> None:
            timers = result["timers"]
            TimerCallback.result_time = (
                timers["training_iteration_time_ms"] - timers["learn_time_ms"]
            )

    def env_creator(config: dict):
        return vmas.make_env(
            scenario=config["scenario_name"],
            num_envs=config["num_envs"],
            device=config["device"],
            continuous_actions=False,
            wrapper=Wrapper.RLLIB,
        )

    if not ray.is_initialized():
        ray.init()
    register_env(scenario_name, env_creator)

    # Share half a GPU with the (single) worker when running on CUDA.
    gpu_share = 0.5 if device == "cuda" else 0
    tune.run(
        PPOTrainer,
        stop={"training_iteration": 1},
        config={
            "seed": seed,
            "framework": "torch",
            "env": scenario_name,
            "train_batch_size": n_envs * n_steps,
            "rollout_fragment_length": n_steps,
            "sgd_minibatch_size": n_envs * n_steps,
            "num_gpus": gpu_share,
            "num_workers": 0,
            "num_gpus_per_worker": gpu_share,
            "num_envs_per_worker": n_envs,
            "batch_mode": "truncate_episodes",
            "env_config": {
                "device": device,
                "num_envs": n_envs,
                "scenario_name": scenario_name,
                "max_steps": n_steps,
            },
            "callbacks": TimerCallback,
        },
    )
    assert TimerCallback.result_time is not None
    TimerCallback.result_time /= 1_000  # convert to seconds
    return TimerCallback.result_time
138
+
139
+
def run_comparison_torchrl_rllib(
    scenario_name: str,
    device: str,
    n_steps: int = 100,
    max_n_envs: int = 3000,
    step_n_envs: int = 3,
):
    """Benchmark and plot TorchRL vs RLlib sampling throughput on a VMAS scenario.

    Results are cached to a pickle next to this script so reruns only benchmark
    frameworks that are missing from the cache; a PDF plot is saved alongside.

    Args:
        scenario_name (str): name of scenario to benchmark
        device (str): device to run comparison on ("cpu" or "cuda")
        n_steps (int): number of environment steps
        max_n_envs (int): the maximum number of parallel environments to test
        step_n_envs (int): the number of environment counts sampled between 1
            and max_n_envs (passed to ``np.linspace`` as the number of points,
            not as a step size, despite the name)
    """
    list_n_envs = np.linspace(1, max_n_envs, step_n_envs)

    figure_name = f"VMAS_{scenario_name}_{n_steps}_{device}_steps_rllib_vs_torchrl"
    figure_name_pkl = figure_name + f"_range_{1}_{max_n_envs}_num_{step_n_envs}"

    # Reuse cached results when available.
    evaluation = load_pickled_evaluation(figure_name_pkl)
    if not evaluation:
        evaluation = {}
    for framework in ["TorchRL", "RLlib"]:
        if framework not in evaluation:
            torchrl_logger.info(f"\nFramework {framework}")
            vmas_times = []
            # Both runners share the same signature, so dispatch once instead
            # of duplicating the call in two branches.
            runner = run_vmas_torchrl if framework == "TorchRL" else run_vmas_rllib
            for n_envs in list_n_envs:
                n_envs = int(n_envs)
                torchrl_logger.info(f"Running {n_envs} environments")
                # Throughput = total frames / wall-clock sampling seconds.
                vmas_times.append(
                    (n_envs * n_steps)
                    / runner(
                        scenario_name=scenario_name,
                        n_envs=n_envs,
                        n_steps=n_steps,
                        device=device,
                    )
                )
                torchrl_logger.info(f"fps {vmas_times[-1]}s")
            evaluation[framework] = vmas_times

    store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation)

    fig, ax = plt.subplots()
    for key, item in evaluation.items():
        ax.plot(
            list_n_envs,
            item,
            label=key,
        )

    plt.xlabel("Number of batched environments", fontsize=14)
    plt.ylabel("Frames per second", fontsize=14)
    ax.legend(loc="upper left")

    ax.set_title(
        f"Execution time of '{scenario_name}' for {n_steps} steps on {device}.",
        fontsize=8,
    )

    save_folder = os.path.dirname(os.path.realpath(__file__))
    plt.savefig(f"{save_folder}/{figure_name}.pdf")
216
+
217
+
if __name__ == "__main__":
    # Dependencies for this benchmark; ray is pinned to 2.1.0 because the
    # script imports the legacy ray.rllib.agents.ppo.PPOTrainer API.
    # pip install matplotlib
    # pip install "ray[rllib]"==2.1.0
    # pip install torchrl
    # pip install vmas
    # pip install numpy==1.23.5

    # NOTE(review): device="cuda" assumes a GPU is present — confirm before
    # running, or switch to "cpu".
    run_comparison_torchrl_rllib(
        scenario_name="simple_spread",
        device="cuda",
        n_steps=100,
        max_n_envs=30000,
        step_n_envs=10,
    )
@@ -0,0 +1,7 @@
1
+ pytest-benchmark
2
+ tenacity
3
+ safetensors
4
+ tqdm
5
+ pandas
6
+ numpy
7
+ matplotlib
@@ -0,0 +1,193 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Sample latency benchmarking (using RPC)
8
+ ======================================
9
+ A rough benchmark of sample latency using different storage types over the network using `torch.rpc`.
10
+ Run this script with --rank=0 and --rank=1 flags set in separate processes - these ranks correspond to the trainer worker and buffer worker respectively, and both need to be initialised.
11
+ e.g. to benchmark LazyMemmapStorage, run the following commands using either two separate shells or multiprocessing.
12
+ - python3 benchmark_sample_latency_over_rpc.py --rank=0 --storage=LazyMemmapStorage
13
+ - python3 benchmark_sample_latency_over_rpc.py --rank=1 --storage=LazyMemmapStorage
14
+ This code is based on examples/distributed/distributed_replay_buffer.py.
15
+ """
16
+ import argparse
17
+ import os
18
+ import pickle
19
+ import sys
20
+ import time
21
+ import timeit
22
+ from datetime import datetime
23
+
24
+ import torch
25
+ import torch.distributed.rpc as rpc
26
+ from tensordict import TensorDict
27
+ from torchrl._utils import logger as torchrl_logger
28
+ from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
29
+ from torchrl.data.replay_buffers.samplers import RandomSampler
30
+ from torchrl.data.replay_buffers.storages import (
31
+ LazyMemmapStorage,
32
+ LazyTensorStorage,
33
+ ListStorage,
34
+ )
35
+ from torchrl.data.replay_buffers.writers import RoundRobinWriter
36
+
# NOTE(review): RETRY_LIMIT is defined but never used below
# (_create_replay_buffer retries forever) — confirm whether a bounded retry
# was intended.
RETRY_LIMIT = 2
RETRY_DELAY_SECS = 3  # seconds to wait between connection attempts
REPLAY_BUFFER_NODE = "ReplayBuffer"  # RPC worker name for the buffer process (rank 1)
TRAINER_NODE = "Trainer"  # RPC worker name for the trainer process (rank 0)
TENSOR_SIZE = 3 * 86 * 86  # flattened per-transition observation size
BUFFER_SIZE = 1001  # number of transitions pre-loaded into the buffer
BATCH_SIZE = 256  # transitions fetched per sample() call
REPEATS = 1000  # number of timed sample() calls (the first one is discarded)

# Maps the --storage CLI choice to the corresponding storage class.
storage_options = {
    "LazyMemmapStorage": LazyMemmapStorage,
    "LazyTensorStorage": LazyTensorStorage,
    "ListStorage": ListStorage,
}

# Extra constructor kwargs per storage type.
storage_arg_options = {
    "LazyMemmapStorage": {"scratch_dir": "/tmp/", "device": torch.device("cpu")},
    "LazyTensorStorage": {},
    "ListStorage": {},
}
parser = argparse.ArgumentParser(
    description="RPC Replay Buffer Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--rank",
    type=int,
    default=-1,
    help="Node Rank [0 = Replay Buffer, 1 = Dummy Trainer, 2+ = Dummy Data Collector]",
)

parser.add_argument(
    "--storage",
    type=str,
    default="LazyMemmapStorage",
    help="Storage type [LazyMemmapStorage, LazyTensorStorage, ListStorage]",
)
75
+
76
+
class DummyTrainerNode:
    """Trainer-side RPC node.

    Holds a remote reference (RRef) to the replay-buffer node and times
    synchronous remote ``sample`` calls against it.
    """

    def __init__(self) -> None:
        self.id = rpc.get_worker_info().id  # this worker's RPC id
        self.replay_buffer = self._create_replay_buffer()  # RRef to ReplayBufferNode
        self._ret = None  # reusable destination for sampled batches

    def train(self, batch_size: int) -> float:
        """Sample ``batch_size`` transitions over RPC; return the latency in seconds."""
        start_time = timeit.default_timer()
        ret = rpc.rpc_sync(
            self.replay_buffer.owner(),
            ReplayBufferNode.sample,
            args=(self.replay_buffer, batch_size),
        )
        # `storage_type` is a module-level global assigned in __main__.
        if storage_type == "ListStorage":
            self._ret = ret[0]
        else:
            if self._ret is None:
                self._ret = ret
            else:
                # Reuse the first TensorDict and copy in place, so repeated
                # calls do not re-allocate the destination buffers.
                self._ret.update_(ret)
        # make sure the content is read
        self._ret["observation"] + 1
        self._ret["next_observation"] + 1
        return timeit.default_timer() - start_time

    def _create_replay_buffer(self) -> rpc.RRef:
        # NOTE(review): retries forever; RETRY_LIMIT (defined above) is never
        # applied here — confirm whether a bounded retry was intended.
        while True:
            try:
                replay_buffer_info = rpc.get_worker_info(REPLAY_BUFFER_NODE)
                # Construct the ReplayBufferNode remotely on the buffer worker.
                buffer_rref = rpc.remote(
                    replay_buffer_info, ReplayBufferNode, args=(1000000,)
                )
                torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}")
                return buffer_rref
            except Exception:
                torchrl_logger.info("Failed to connect to replay buffer")
                time.sleep(RETRY_DELAY_SECS)
114
+
115
+
class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    """Buffer-side RPC node.

    A remote replay buffer backed by the storage class selected on the
    command line, pre-filled with ``BUFFER_SIZE`` random transitions so the
    trainer can sample immediately.
    """

    def __init__(self, capacity: int):
        storage_cls = storage_options[storage_type]
        super().__init__(
            storage=storage_cls(
                max_size=capacity, **storage_arg_options[storage_type]
            ),
            sampler=RandomSampler(),
            writer=RoundRobinWriter(),
            collate_fn=lambda x: x,
        )
        # Seed the buffer with random observation pairs of the benchmarked size.
        observations = torch.randn(
            BUFFER_SIZE,
            TENSOR_SIZE,
        )
        next_observations = torch.randn(
            BUFFER_SIZE,
            TENSOR_SIZE,
        )
        seed_data = TensorDict(
            {
                "observation": observations,
                "next_observation": next_observations,
            },
            batch_size=[BUFFER_SIZE],
        )
        self.extend(seed_data)
140
+
141
+
if __name__ == "__main__":
    args = parser.parse_args()
    rank = args.rank
    # Read as a module-level global by DummyTrainerNode and ReplayBufferNode.
    storage_type = args.storage

    torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

    # Both processes must agree on these rendezvous settings.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16, init_method="tcp://localhost:10002", rpc_timeout=120
    )
    if rank == 0:
        # rank 0 is the trainer
        rpc.init_rpc(
            TRAINER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        trainer = DummyTrainerNode()
        results = []
        for i in range(REPEATS):
            result = trainer.train(batch_size=BATCH_SIZE)
            if i == 0:
                # Discard the first (warm-up) measurement.
                continue
            results.append(result)
            torchrl_logger.info(f"{i}, {results[-1]}")

        # Dump the raw latencies to a timestamped pickle for offline analysis.
        with open(
            f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl',
            "wb+",
        ) as f:
            pickle.dump(results, f)

        tensor_results = torch.tensor(results)
        torchrl_logger.info(f"Mean: {torch.mean(tensor_results)}")
        # NOTE(review): breakpoint() appears to hold the process open before
        # shutdown — confirm this is intentional and not debug residue.
        breakpoint()
    elif rank == 1:
        # rank 1 is the replay buffer
        # replay buffer waits passively for construction instructions from trainer node
        rpc.init_rpc(
            REPLAY_BUFFER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        # NOTE(review): same breakpoint() pattern as rank 0 — presumably keeps
        # the buffer worker alive while the trainer runs; verify.
        breakpoint()
    else:
        sys.exit(1)
    rpc.shutdown()
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import argparse
6
+ import time
7
+
8
+ import pytest
9
+ import torch.cuda
10
+ import tqdm
11
+
12
+ from torchrl.collectors import (
13
+ MultiaSyncDataCollector,
14
+ MultiSyncDataCollector,
15
+ SyncDataCollector,
16
+ )
17
+ from torchrl.data import LazyTensorStorage, ReplayBuffer
18
+ from torchrl.data.utils import CloudpickleWrapper
19
+ from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
20
+ from torchrl.envs.libs.dm_control import DMControlEnv
21
+ from torchrl.modules import RandomPolicy
22
+
23
+
24
def single_collector_setup():
    """Build a warmed-up single-process ``SyncDataCollector`` on cheetah-run.

    Returns an ``((iterator,), {})`` args/kwargs pair for the benchmark
    harness; the iterator has already been advanced past a warm-up phase so
    subsequent timings reflect steady-state collection.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = TransformedEnv(DMControlEnv("cheetah", "run", device=device), StepCounter(50))
    collector = SyncDataCollector(
        env,
        RandomPolicy(env.action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
39
+
40
+
41
def sync_collector_setup():
    """Build a warmed-up ``MultiSyncDataCollector`` over two cheetah-run envs.

    Returns ``((iterator,), {})``; the iterator has already yielded several
    batches as warm-up so the benchmark measures steady-state throughput.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    make_env = EnvCreator(
        lambda: TransformedEnv(
            DMControlEnv("cheetah", "run", device=device), StepCounter(50)
        )
    )
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(make_env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
60
+
61
+
62
def async_collector_setup():
    """Build a warmed-up ``MultiaSyncDataCollector`` over two cheetah-run envs.

    Returns ``((iterator,), {})``; the iterator has already yielded several
    batches as warm-up so the benchmark measures steady-state throughput.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    make_env = EnvCreator(
        lambda: TransformedEnv(
            DMControlEnv("cheetah", "run", device=device), StepCounter(50)
        )
    )
    collector = MultiaSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(make_env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
81
+
82
+
83
def single_collector_setup_pixels():
    """Build a warmed-up single-process collector over a pixel-based env.

    Uses ALE/Pong-v5 as the pixel environment (a ``from_pixels`` DMControl
    variant exists disabled in this file's history). Returns
    ``((iterator,), {})`` with the iterator past its warm-up phase.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    # Alternative (disabled): DMControlEnv("cheetah", "run", device=device, from_pixels=True)
    env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50))
    collector = SyncDataCollector(
        env,
        RandomPolicy(env.action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
101
+
102
+
103
def sync_collector_setup_pixels():
    """Build a warmed-up ``MultiSyncDataCollector`` over two pixel-based envs.

    Uses ALE/Pong-v5 (a ``from_pixels`` DMControl variant exists disabled in
    this file's history). Returns ``((iterator,), {})`` with the iterator past
    its warm-up phase.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    make_env = EnvCreator(
        lambda: TransformedEnv(
            # Alternative (disabled): DMControlEnv("cheetah", "run", device=device, from_pixels=True)
            GymEnv("ALE/Pong-v5"),
            StepCounter(50),
        )
    )
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(make_env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
124
+
125
+
126
def async_collector_setup_pixels():
    """Build a warmed-up ``MultiaSyncDataCollector`` over two pixel-based envs.

    Uses ALE/Pong-v5 (a ``from_pixels`` DMControl variant exists disabled in
    this file's history). Returns ``((iterator,), {})`` with the iterator past
    its warm-up phase.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    make_env = EnvCreator(
        lambda: TransformedEnv(
            # Alternative (disabled): DMControlEnv("cheetah", "run", device=device, from_pixels=True)
            GymEnv("ALE/Pong-v5"),
            StepCounter(50),
        )
    )
    collector = MultiaSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(make_env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    collector = iter(collector)
    # Warm-up: consume 11 batches (indices 0..10) before handing over.
    consumed = 0
    for _ in collector:
        if consumed == 10:
            break
        consumed += 1
    return ((collector,), {})
147
+
148
+
149
def execute_collector(c):
    """Advance the (already warmed-up) collector iterator by one batch.

    The setup functions consume the initial warm-up batches, so timing this
    call measures steady-state collection of a single batch.
    """
    next(c)
152
+
153
+
154
def test_single(benchmark):
    """Benchmark one batch from a warmed-up single-process collector."""
    pos_args, _ = single_collector_setup()
    benchmark(execute_collector, pos_args[0])
157
+
158
+
159
def test_sync(benchmark):
    """Benchmark one batch from a warmed-up multi-process sync collector."""
    pos_args, _ = sync_collector_setup()
    benchmark(execute_collector, pos_args[0])
162
+
163
+
164
def test_async(benchmark):
    """Benchmark one batch from a warmed-up multi-process async collector."""
    pos_args, _ = async_collector_setup()
    benchmark(execute_collector, pos_args[0])
167
+
168
+
169
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_single_pixels(benchmark):
    """Benchmark one pixel-env batch from a single-process collector."""
    pos_args, _ = single_collector_setup_pixels()
    benchmark(execute_collector, pos_args[0])
173
+
174
+
175
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_sync_pixels(benchmark):
    """Benchmark one pixel-env batch from a multi-process sync collector."""
    pos_args, _ = sync_collector_setup_pixels()
    benchmark(execute_collector, pos_args[0])
179
+
180
+
181
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_async_pixels(benchmark):
    """Benchmark one pixel-env batch from a multi-process async collector."""
    pos_args, _ = async_collector_setup_pixels()
    benchmark(execute_collector, pos_args[0])
185
+
186
+
187
class TestRBGCollector:
    """Throughput benchmark for ``MultiaSyncDataCollector`` with a replay buffer.

    Collects ALE/Pong-v5 frames straight into a ``ReplayBuffer`` and reports a
    running frames-per-second figure via a tqdm progress bar, parameterized
    over the number of collectors and of env workers per collector.
    """

    @pytest.mark.parametrize(
        # Parameter name fixed from the misspelled "n_wokrers_per_col".
        "n_col,n_workers_per_col",
        [
            [2, 2],
            [4, 2],
            [8, 2],
            [16, 2],
            [2, 1],
            [4, 1],
            [8, 1],
            [16, 1],
        ],
    )
    def test_multiasync_rb(self, n_col, n_workers_per_col):
        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
        if n_workers_per_col > 1:
            # Batch several env workers inside each collector process.
            make_env = ParallelEnv(n_workers_per_col, make_env)
            env = make_env
            policy = RandomPolicy(env.action_spec)
        else:
            env = make_env()
            policy = RandomPolicy(env.action_spec)

        # Seed the buffer with one short rollout; the inverse transform
        # flattens incoming batches before they are written.
        storage = LazyTensorStorage(10_000)
        rb = ReplayBuffer(storage=storage)
        rb.extend(env.rollout(2, policy).reshape(-1))
        rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)

        fpb = n_workers_per_col * 100
        total_frames = n_workers_per_col * 100_000
        c = MultiaSyncDataCollector(
            [make_env] * n_col,
            policy,
            frames_per_batch=fpb,
            total_frames=total_frames,
            replay_buffer=rb,
        )
        frames = 0
        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
        # Skip the first n_col batches (one warm-up batch per collector)
        # before starting the FPS clock.
        for i, _ in enumerate(c):
            if i == n_col:
                t0 = time.time()
            if i >= n_col:
                frames += fpb
            if i > n_col:
                fps = frames / (time.time() - t0)
                pbar.update(fpb)
                pbar.set_description(f"fps: {fps: 4.4f}")
        # Release the progress-bar resources once collection is done.
        pbar.close()
236
+
237
+
238
if __name__ == "__main__":
    # Forward any unrecognized CLI flags to pytest; run with live output and
    # stop at the first failure.
    _, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)