torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1105 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import abc
+ import multiprocessing
+
+ from collections.abc import Callable, Mapping, Sequence
+ from concurrent.futures import as_completed, ThreadPoolExecutor
+
+ # import queue
+ from multiprocessing import Queue
+ from queue import Empty
+ from typing import Literal
+
+ import torch
+ from tensordict import (
+     lazy_stack,
+     LazyStackedTensorDict,
+     maybe_dense_stack,
+     TensorDict,
+     TensorDictBase,
+ )
+
+ from tensordict.tensorclass import NonTensorData, NonTensorStack
+ from tensordict.utils import _zip_strict, expand_as_right
+
+ from torchrl.data.tensor_specs import NonTensor
+ from torchrl.envs.common import _EnvPostInit, EnvBase
+
+
+ class _AsyncEnvMeta(_EnvPostInit):
+     """A metaclass for asynchronous environment pools that determines the backend implementation to use based on the provided arguments.
+
+     This class is responsible for instantiating the appropriate subclass of `AsyncEnvPool` based on the specified
+     backend, such as threading or multiprocessing.
+     """
+
+     def __call__(cls, *args, **kwargs):
+         backend = kwargs.get("backend", "threading")
+         if cls is AsyncEnvPool:
+             if backend == "threading":
+                 instance: ThreadingAsyncEnvPool = ThreadingAsyncEnvPool(*args, **kwargs)
+             elif backend == "multiprocessing":
+                 instance: ProcessorAsyncEnvPool = ProcessorAsyncEnvPool(*args, **kwargs)
+             elif backend == "asyncio":
+                 raise NotImplementedError
+                 # instance: AsyncioAsyncEnvPool = AsyncioAsyncEnvPool(*args, **kwargs)
+             else:
+                 raise NotImplementedError
+             return instance
+         else:
+             return super().__call__(*args, **kwargs)
+
+
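+ # A minimal sketch of the dispatch behavior implemented by `_AsyncEnvMeta` above:
+ # constructing the abstract `AsyncEnvPool` returns a backend-specific subclass.
+ # (Illustrative only; assumes the Gym environments used in the docstring below
+ # are available.)
+ #
+ #     >>> from functools import partial
+ #     >>> from torchrl.envs import AsyncEnvPool, GymEnv
+ #     >>> pool = AsyncEnvPool([partial(GymEnv, "Pendulum-v1")], backend="threading")
+ #     >>> type(pool).__name__
+ #     'ThreadingAsyncEnvPool'
+ #     >>> pool = AsyncEnvPool([partial(GymEnv, "Pendulum-v1")], backend="multiprocessing")
+ #     >>> type(pool).__name__
+ #     'ProcessorAsyncEnvPool'
+
+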
+ class AsyncEnvPool(EnvBase, metaclass=_AsyncEnvMeta):
+     """A base class for asynchronous environment pools, providing a common interface for managing multiple environments concurrently.
+
+     This class supports different backends for parallel execution, such as threading
+     and multiprocessing, and provides methods for asynchronous stepping and resetting
+     of environments.
+
+     .. note:: This class and its subclasses should work when nested within :class:`~torchrl.envs.TransformedEnv` and
+         batched environments, but users currently cannot use the async features of the base environment when
+         it is nested in these classes. One should prefer nesting transformed envs within an `AsyncEnvPool` instead.
+         If this is not possible, please raise an issue.
+
+     Args:
+         env_makers (Callable[[], EnvBase] | EnvBase | list[EnvBase] | list[Callable[[], EnvBase]]):
+             A callable or list of callables that create environment instances, or
+             environment instances themselves.
+         backend (Literal["threading", "multiprocessing", "asyncio"], optional):
+             The backend to use for parallel execution. Defaults to `"threading"`.
+         stack (Literal["dense", "maybe_dense", "lazy"], optional):
+             The method to use for stacking environment outputs. Defaults to `"dense"`.
+         create_env_kwargs (dict, optional):
+             Keyword arguments to pass to the environment maker. Defaults to `{}`.
+
+     Attributes:
+         min_get (int): Minimum number of environments to process in a batch.
+         env_makers (list): List of environment makers or environments.
+         num_envs (int): Number of environments in the pool.
+         backend (str): Backend used for parallel execution.
+         stack (str): Method used for stacking environment outputs.
+
+     Examples:
+         >>> from functools import partial
+         >>> from torchrl.envs import AsyncEnvPool, GymEnv
+         >>> import torch
+         >>> # Choose backend
+         >>> backend = "threading"
+         >>> env = AsyncEnvPool([partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")], stack="lazy", backend=backend)
+         >>> assert env.batch_size == (2,)
+         >>> # Execute a sync reset
+         >>> reset = env.reset()
+         >>> print(reset)
+         LazyStackedTensorDict(
+             fields={
+                 done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 env_index: NonTensorStack(
+                     [0, 1],
+                     batch_size=torch.Size([2]),
+                     device=None),
+                 observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             exclusive_fields={
+             },
+             batch_size=torch.Size([2]),
+             device=None,
+             is_shared=False,
+             stack_dim=0)
+         >>> # Execute a sync step
+         >>> s = env.rand_step(reset)
+         >>> print(s)
+         LazyStackedTensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                 done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 env_index: NonTensorStack(
+                     [0, 1],
+                     batch_size=torch.Size([2]),
+                     device=None),
+                 next: LazyStackedTensorDict(
+                     fields={
+                         done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     exclusive_fields={
+                     },
+                     batch_size=torch.Size([2]),
+                     device=None,
+                     is_shared=False,
+                     stack_dim=0),
+                 observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             exclusive_fields={
+             },
+             batch_size=torch.Size([2]),
+             device=None,
+             is_shared=False,
+             stack_dim=0)
+         >>> s = env.step_mdp(s)
+         >>> # Execute an asynchronous step in env 0
+         >>> s0 = s[0]
+         >>> s0["action"] = torch.randn(1).clamp(-1, 1)
+         >>> # We must tell the env which data this is from
+         >>> s0["env_index"] = 0
+         >>> env.async_step_send(s0)
+         >>> # Receive data
+         >>> s0_result = env.async_step_recv()
+         >>> print('result', s0_result)
+         result LazyStackedTensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                 done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 env_index: NonTensorStack(
+                     [0],
+                     batch_size=torch.Size([1]),
+                     device=None),
+                 next: LazyStackedTensorDict(
+                     fields={
+                         done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         observation: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     exclusive_fields={
+                     },
+                     batch_size=torch.Size([1]),
+                     device=None,
+                     is_shared=False,
+                     stack_dim=0),
+                 observation: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             exclusive_fields={
+             },
+             batch_size=torch.Size([1]),
+             device=None,
+             is_shared=False,
+             stack_dim=0)
+         >>> # Close env
+         >>> env.close()
+
+     """
+
+     _env_idx_key = "env_index"
+
+     def __init__(
+         self,
+         env_makers: Callable[[], EnvBase]
+         | EnvBase
+         | list[EnvBase]
+         | list[Callable[[], EnvBase]],
+         *,
+         backend: Literal["threading", "multiprocessing", "asyncio"] = "threading",
+         stack: Literal["dense", "maybe_dense", "lazy"] = "dense",
+         create_env_kwargs: dict | list[dict] | None = None,
+     ) -> None:
+         if not isinstance(env_makers, Sequence):
+             env_makers = [env_makers]
+         self.min_get = 1
+         self.env_makers = env_makers
+         self.num_envs = len(env_makers)
+         self.backend = backend
+         if create_env_kwargs is None:
+             create_env_kwargs = {}
+         if isinstance(create_env_kwargs, Mapping):
+             create_env_kwargs = [create_env_kwargs] * self.num_envs
+         if len(create_env_kwargs) != self.num_envs:
+             raise ValueError(
+                 f"create_env_kwargs must be a dict or a list of dicts with length {self.num_envs}"
+             )
+         self.create_env_kwargs = create_env_kwargs
+
+         self.stack = stack
+         if stack == "dense":
+             self._stack_func = torch.stack
+         elif stack == "maybe_dense":
+             self._stack_func = maybe_dense_stack
+         elif stack == "lazy":
+             self._stack_func = lazy_stack
+         else:
+             raise NotImplementedError
+
+         output_spec, input_spec = self._setup()
+         input_spec["full_state_spec"].set(
+             self._env_idx_key, NonTensor(example_data=0, shape=input_spec.shape)
+         )
+         self.__dict__["_output_spec"] = output_spec
+         self.__dict__["_input_spec"] = input_spec
+         # Use spec shape as batch_size since it correctly includes both the pool dimension
+         # and the child env batch dimensions (e.g., (4, 1) for 4 envs with batch_size=(1,))
+         super().__init__(batch_size=input_spec.shape)
+         self._busy = set()
+
+     @property
+     def env_batch_sizes(self) -> list[torch.Size]:
+         """Returns the batch-sizes of every env."""
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def _get_child_specs(self) -> list:
+         """Returns the list of child env specs for stacking.
+
+         For ThreadingAsyncEnvPool, returns [env.full_*_spec for env in self.envs].
+         For ProcessorAsyncEnvPool, returns cached specs from setup.
+         """
+         raise NotImplementedError
+
+     # Override spec properties to properly stack child env specs.
+     # This bypasses the problematic StackedComposite.get() behavior that loses
+     # nested keys like full_action_spec when cloning stacked specs.
+
+     @property
+     def full_action_spec(self):
+         child_specs = self._get_child_specs()
+         return torch.stack(
+             [s["input_spec"]["full_action_spec"] for s in child_specs], dim=0
+         )
+
+     @property
+     def full_observation_spec(self):
+         child_specs = self._get_child_specs()
+         return torch.stack(
+             [s["output_spec"]["full_observation_spec"] for s in child_specs], dim=0
+         )
+
+     @property
+     def full_reward_spec(self):
+         child_specs = self._get_child_specs()
+         return torch.stack(
+             [s["output_spec"]["full_reward_spec"] for s in child_specs], dim=0
+         )
+
+     @property
+     def full_done_spec(self):
+         child_specs = self._get_child_specs()
+         return torch.stack(
+             [s["output_spec"]["full_done_spec"] for s in child_specs], dim=0
+         )
+
+     @property
+     def full_state_spec(self):
+         child_specs = self._get_child_specs()
+         specs = torch.stack(
+             [s["input_spec"]["full_state_spec"] for s in child_specs], dim=0
+         )
+         # Add env_index key for async tracking
+         specs.set(self._env_idx_key, NonTensor(example_data=0, shape=specs.shape))
+         return specs
+
+     # TODO: _make_single_env_spec (used by *_unbatched properties) takes spec[0],
+     # which assumes all child envs have identical specs. Should add validation
+     # that child specs match, and error if they differ.
+
+     def _reset(
+         self,
+         tensordict: TensorDictBase | None = None,
+         **kwargs,
+     ) -> TensorDictBase:
+         if self._current_step > 0:
+             raise RuntimeError("Some envs are still processing a step.")
+         if tensordict is None:
+             if self.stack in ("lazy", "maybe_dense"):
+                 tensordict = LazyStackedTensorDict(
+                     *[
+                         TensorDict(batch_size=self.env_batch_sizes[i])
+                         for i in range(self.num_envs)
+                     ]
+                 )
+             else:
+                 tensordict = TensorDict(
+                     batch_size=(self.num_envs,) + self.env_batch_sizes[0]
+                 )
+         env_idx_nt = NonTensorStack(*range(tensordict.shape[0]))
+         while env_idx_nt.batch_dims < tensordict.batch_dims:
+             env_idx_nt = expand_as_right(env_idx_nt, tensordict)
+         tensordict[self._env_idx_key] = env_idx_nt
+         self._async_private_reset_send(tensordict)
+         tensordict = self._async_private_reset_recv(min_get=self.num_envs)
+         return tensordict
+
+     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         if self._current_step > 0:
+             raise RuntimeError("Some envs are still processing a step.")
+         tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0]))
+         self._async_private_step_send(tensordict)
+         tensordict = self._async_private_step_recv(min_get=self.num_envs)
+         # Using pop instead of del to account for tensorclasses
+         tensordict.pop(self._env_idx_key)
+         return tensordict
+
+     def step_and_maybe_reset(
+         self, tensordict: TensorDictBase
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         if self._current_step_reset > 0:
+             raise RuntimeError("Some envs are still processing a step.")
+         tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0]))
+         self.async_step_and_maybe_reset_send(tensordict)
+         tensordict, tensordict_ = self.async_step_and_maybe_reset_recv(
+             min_get=self.num_envs
+         )
+         return tensordict, tensordict_
+
+     def step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         if self._current_step > 0:
+             raise RuntimeError("Some envs are still processing a step.")
+         tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0]))
+         self.async_step_send(tensordict)
+         tensordict = self.async_step_recv(min_get=self.num_envs)
+         return tensordict
+
+     def reset(
+         self,
+         tensordict: TensorDictBase | None = None,
+         **kwargs,
+     ) -> TensorDictBase:
+         if self._current_step > 0:
+             raise RuntimeError("Some envs are still processing a step.")
+         if tensordict is None:
+             if self.stack in ("lazy", "maybe_dense"):
+                 tensordict = LazyStackedTensorDict(
+                     *[
+                         TensorDict(batch_size=self.env_batch_sizes[i])
+                         for i in range(self.num_envs)
+                     ]
+                 )
+             else:
+                 tensordict = TensorDict(
+                     batch_size=(self.num_envs,) + self.env_batch_sizes[0]
+                 )
+         indices = NonTensorStack(*range(tensordict.shape[0]))
+         if indices.shape != tensordict.shape:
+             indices = expand_as_right(indices, tensordict)
+         tensordict[self._env_idx_key] = indices
+         self.async_reset_send(tensordict)
+         tensordict = self.async_reset_recv(min_get=self.num_envs)
+         return tensordict
+
+     def _sort_results(self, results, *other_results):
+         # Extract env indices from results. When child envs have a batch dimension
+         # (e.g., batch_size=(1,)), r[self._env_idx_key] may be a 1D sequence
+         # instead of a scalar, so we need to handle both cases.
+         idx = []
+         for r in results:
+             env_idx = r[self._env_idx_key]
+             # Handle sequence types (NonTensorStack, etc.) by taking the first element
+             while hasattr(env_idx, "__len__") and not isinstance(env_idx, (str, bytes)):
+                 if len(env_idx) == 1:
+                     env_idx = env_idx[0]
+                 else:
+                     break
+             idx.append(int(env_idx))
+         argsort = torch.argsort(torch.tensor(idx)).tolist()
+         results = [results[i] for i in argsort]
+         if other_results:
+             other_results = [
+                 [other_result[i] for i in argsort] for other_result in other_results
+             ]
+             return results, *other_results, idx
+         return results, idx
+
+     def _set_seed(self, seed: int | None) -> None:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def _setup(self):
+         raise NotImplementedError
+
+     def _maybe_make_tensordict(self, tensordict, env_index, make_if_none):
+         if env_index is None:
+             env_idx = tensordict.view(-1)[self._env_idx_key]
+             if isinstance(env_idx, torch.Tensor):
+                 env_idx = env_idx.tolist()
+             if isinstance(env_idx, int):
+                 # If we squeezed a td with shape (1,) and got a NonTensorStack -> NonTensorData, then
+                 # unsqueezed the NonTensorData, we'd still have a NonTensorData with shape (1,).
+                 # This will give us an integer now, but we don't want to unsqueeze the full td because then
+                 # we'd have a td with shape (1, 1).
+                 if tensordict.shape != (1, *self.env_batch_sizes[env_idx]):
+                     tensordict = tensordict.unsqueeze(0)
+                 env_idx = [env_idx]
+         elif isinstance(env_index, int):
+             if make_if_none and tensordict is None:
+                 tensordict = TensorDict(
+                     batch_size=self.env_batch_sizes[env_index], device=self.device
+                 )
+             if self.stack in ("lazy", "maybe_dense"):
+                 tensordict = tensordict.unsqueeze(0)
+             else:
+                 tensordict = lazy_stack([tensordict])
+             tensordict[self._env_idx_key] = NonTensorStack(env_index)
+             env_idx = [env_index]
+         else:
+             if make_if_none and tensordict is None:
+                 if self.stack in ("lazy", "maybe_dense"):
+                     tensordict = lazy_stack(
+                         [TensorDict(device=self.device) for _ in env_index]
+                     )
+                 else:
+                     tensordict = TensorDict(
+                         batch_size=(len(env_index),), device=self.device
+                     )
+             tensordict[self._env_idx_key] = NonTensorStack(*env_index)
+             env_idx = env_index
+         return tensordict, env_idx
+
+     @abc.abstractmethod
+     def async_step_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def async_step_recv(self, min_get: int | None = None) -> TensorDictBase:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def async_step_and_maybe_reset_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def async_step_and_maybe_reset_recv(
+         self,
+         min_get: int | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def async_reset_send(
+         self,
+         tensordict: TensorDictBase | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> None:
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def async_reset_recv(self, min_get: int | None = None) -> TensorDictBase:
+         raise NotImplementedError
+
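+     # A minimal sketch of the async API declared above, assuming a pool of three
+     # environments (illustrative only): `min_get` lets the caller trade latency
+     # for batch size when gathering results.
+     #
+     #     >>> pool.async_reset_send(env_index=[0, 1, 2])
+     #     >>> td = pool.async_reset_recv(min_get=2)   # blocks until at least 2 envs are done
+     #     >>> td["action"] = torch.randn(td.shape[0], 1).clamp(-1, 1)
+     #     >>> pool.async_step_send(td)                # the "env_index" entry routes each row
+     #     >>> out = pool.async_step_recv(min_get=1)   # return whichever envs finish first
+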
+     def __del__(self):
+         self._maybe_shutdown()
+
+     def _maybe_shutdown(self):
+         try:
+             self.shutdown()
+         except Exception:
+             pass
+
+     @abc.abstractmethod
+     def shutdown(self):
+         raise NotImplementedError
+
+     def close(self, *, raise_if_closed: bool = True):
+         if raise_if_closed:
+             self.shutdown()
+         else:
+             self._maybe_shutdown()
+
+
+ class ProcessorAsyncEnvPool(AsyncEnvPool):
+     """An implementation of `AsyncEnvPool` using multiprocessing for parallel execution of environments.
+
+     This class manages a pool of environments, each running in its own process, and
+     provides methods for asynchronous stepping and resetting of environments using
+     inter-process communication.
+
+     .. note:: This class and its subclasses should work when nested within :class:`~torchrl.envs.TransformedEnv` and
+         batched environments, but users currently cannot use the async features of the base environment when
+         it is nested in these classes. One should prefer nesting transformed envs within an `AsyncEnvPool` instead.
+         If this is not possible, please raise an issue.
+
+     Methods:
+         _setup(): Initializes the multiprocessing queues and processes for each
+             environment.
+         async_step_send(tensordict): Sends a step command to the environments.
+         async_step_recv(min_get): Receives the results of the step command.
+         async_reset_send(tensordict): Sends a reset command to the environments.
+         async_reset_recv(min_get): Receives the results of the reset command.
+         shutdown(): Shuts down all environment processes.
+     """
+
+     def _setup(self):
+         self.step_queue = Queue(maxsize=self.num_envs)
+         self.reset_queue = Queue(maxsize=self.num_envs)
+         self.step_reset_queue = Queue(maxsize=self.num_envs)
+         self.input_queue = [Queue(maxsize=1) for _ in range(self.num_envs)]
+         self.output_queue = [Queue(maxsize=1) for _ in range(self.num_envs)]
+         self._current_reset = 0
+         self._current_step = 0
+         self._current_step_reset = 0
+
+         num_threads = self.num_envs
+         self.threads = []
+         for i in range(num_threads):
+             # thread = threading.Thread(target=_env_exec, kwargs={"i": i, "env_or_factory": self.env_maker[i], "input_queue": self.input_queue[i], "step_queue": self.step_queue, "reset_queue": self.reset_queue})
+             thread = multiprocessing.Process(
+                 target=self._env_exec,
+                 kwargs={
+                     "i": i,
+                     "env_or_factory": self.env_makers[i],
+                     "create_env_kwargs": self.create_env_kwargs[i],
+                     "input_queue": self.input_queue[i],
+                     "output_queue": self.output_queue[i],
+                     "step_reset_queue": self.step_reset_queue,
+                     "step_queue": self.step_queue,
+                     "reset_queue": self.reset_queue,
+                 },
+             )
+             self.threads.append(thread)
+             thread.start()
+         # Get specs from each worker and cache them for _get_child_specs()
+         for i in range(num_threads):
+             self.input_queue[i].put(("get_specs", None))
+         self._child_specs = []
+         for i in range(num_threads):
+             self._child_specs.append(self.output_queue[i].get())
+         specs = torch.stack(list(self._child_specs))
+         output_spec = specs["output_spec"]
+         input_spec = specs["input_spec"]
+         return output_spec, input_spec
+
+     def _get_child_specs(self) -> list:
+         """Returns the cached specs from each child environment process."""
+         return self._child_specs
+
+     @property
+     def env_batch_sizes(self) -> list[torch.Size]:
+         batch_sizes = getattr(self, "_env_batch_sizes", [])
+         if not batch_sizes:
+             for _env_idx in range(self.num_envs):
+                 self.input_queue[_env_idx].put(("batch_size", None))
+                 batch_sizes.append(self.output_queue[_env_idx].get())
+             self._env_batch_sizes = batch_sizes
+         return batch_sizes
+
+     def async_step_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         # Put the tds in a queue and ask for env.step
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         local_tds = tensordict.unbind(0)
+         for _env_idx, local_td in _zip_strict(env_idx, local_tds):
+             self.input_queue[_env_idx].put(("step", local_td))
+         self._current_step = self._current_step + len(env_idx)
+
+     def async_step_recv(self, min_get: int | None = None) -> TensorDictBase:
+         # Get step results from the queue
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_step:
+             raise RuntimeError(
+                 f"Cannot await {min_get} steps when only {self._current_step} are being stepped."
+             )
+         r = self._wait_for_one_and_get(self.step_queue, min_get)
+         self._current_step = self._current_step - len(r)
+         r, idx = self._sort_results(r)
+         self._busy.difference_update(idx)
+         return self._stack_func(r)
+
+     def _async_private_step_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         # Put the tds in a queue and ask for env._step
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         local_tds = tensordict.unbind(0)
+         for _env_idx, local_td in _zip_strict(env_idx, local_tds):
+             self.input_queue[_env_idx].put(("_step", local_td))
+         self._current_step = self._current_step + len(env_idx)
+
+     _async_private_step_recv = async_step_recv
+
+     def async_step_and_maybe_reset_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         # Put the tds in a queue and ask for env.step_and_maybe_reset
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+         local_tds = tensordict.unbind(0)
+         for _env_idx, local_td in _zip_strict(env_idx, local_tds):
+             self._current_step_reset = self._current_step_reset + 1
+             self.input_queue[_env_idx].put(("step_and_maybe_reset", local_td))
+
+     def async_step_and_maybe_reset_recv(
+         self, min_get: int | None = None
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         # Get step_and_maybe_reset results from the queue
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_step_reset:
+             raise RuntimeError(
+                 f"Cannot await {min_get} step_and_maybe_reset when only {self._current_step_reset} are being stepped."
+             )
+         r = self._wait_for_one_and_get(self.step_reset_queue, min_get)
+         self._current_step_reset = self._current_step_reset - len(r)
+         r, r_ = zip(*r)
+         r, r_, idx = self._sort_results(r, r_)
+         self._busy.difference_update(idx)
+         return self._stack_func(r), self._stack_func(r_)
+
+     def async_reset_send(
+         self,
+         tensordict: TensorDictBase | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> None:
+         # Put the tds in a queue and ask for env.reset
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, True)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+         local_tds = tensordict.unbind(0)
+         for _env_idx, local_td in _zip_strict(env_idx, local_tds):
+             self._current_reset = self._current_reset + 1
+             self.input_queue[_env_idx].put(("reset", local_td))
+
+     def async_reset_recv(self, min_get: int | None = None) -> TensorDictBase:
+         # Get reset results from the queue
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_reset:
+             raise RuntimeError(
+                 f"Cannot await {min_get} resets when only {self._current_reset} are being reset."
+             )
+         r = self._wait_for_one_and_get(self.reset_queue, min_get)
+         self._current_reset = self._current_reset - len(r)
+         r, idx = self._sort_results(r)
+         self._busy.difference_update(idx)
+         return self._stack_func(r)
+
+     def _async_private_reset_send(
+         self,
+         tensordict: TensorDictBase | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> None:
+         # Put the tds in a queue and ask for env._reset
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, True)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+         local_tds = tensordict.unbind(0)
+         for _env_idx, local_td in _zip_strict(env_idx, local_tds):
+             self._current_reset = self._current_reset + 1
+             self.input_queue[_env_idx].put(("_reset", local_td))
+
+     _async_private_reset_recv = async_reset_recv
+
+     def _wait_for_one_and_get(self, q, min_get):
+         # Block for at least one item
+         items = [q.get()]
+
+         # Retrieve all other items that are already available
+         try:
+             while True:
+                 item = q.get_nowait()
+                 items.append(item)
+         except Empty:
+             pass
+
+         # Block until at least min_get items have been gathered
+         while len(items) < min_get:
+             item = q.get()
+             items.append(item)
+
+         return items
+
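+     # The helper above follows a "block, drain, top up" pattern: one blocking get
+     # guarantees progress, the non-blocking loop drains whatever is already
+     # available, and further blocking gets run only while fewer than `min_get`
+     # items have been gathered. A self-contained sketch of the same pattern with
+     # a plain queue (illustrative only):
+     #
+     #     >>> import queue
+     #     >>> q = queue.Queue()
+     #     >>> for i in range(3):
+     #     ...     q.put(i)
+     #     >>> items = [q.get()]                     # block for at least one item
+     #     >>> try:
+     #     ...     while True:
+     #     ...         items.append(q.get_nowait())  # drain what is ready
+     #     ... except queue.Empty:
+     #     ...     pass
+     #     >>> items
+     #     [0, 1, 2]
+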
+     def shutdown(self):
+         for env_id in range(self.num_envs):
+             self.input_queue[env_id].put(("shutdown", None))
+
+         for thread in self.threads:
+             thread.join()
+
+     @classmethod
+     def _env_exec(
+         cls,
+         i,
+         env_or_factory,
+         create_env_kwargs,
+         input_queue,
+         output_queue,
+         step_queue,
+         step_reset_queue,
+         reset_queue,
+     ):
+         if not isinstance(env_or_factory, EnvBase):
+             env = env_or_factory(**create_env_kwargs)
+         else:
+             env = env_or_factory
+
+         while True:
+             msg_data = input_queue.get()
+             msg, data = msg_data
+             if msg == "get_specs":
+                 output_queue.put(env.specs)
+             elif msg == "batch_size":
+                 output_queue.put(env.batch_size)
+             elif msg == "reset":
+                 data = env.reset(data.copy())
+                 data.set(cls._env_idx_key, NonTensorData(i))
+                 reset_queue.put(data)
+             elif msg == "_reset":
+                 data = env._reset(data.copy())
+                 data.set(cls._env_idx_key, NonTensorData(i))
+                 reset_queue.put(data)
+             elif msg == "step_and_maybe_reset":
+                 data, data_ = env.step_and_maybe_reset(data.copy())
+                 data.set(cls._env_idx_key, NonTensorData(i))
+                 data_.set(cls._env_idx_key, NonTensorData(i))
+                 step_reset_queue.put((data, data_))
+             elif msg == "step":
+                 data = env.step(data.copy())
+                 data.set(cls._env_idx_key, NonTensorData(i))
+                 step_queue.put(data)
+             elif msg == "_step":
+                 data = env._step(data.copy())
+                 data.set(cls._env_idx_key, NonTensorData(i))
+                 step_queue.put(data)
+             elif msg == "shutdown":
+                 env.close()
+                 break
+             else:
+                 raise RuntimeError(f"Unknown msg {msg} for worker {i}")
+         return
+
+
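+ # The worker loop in `_env_exec` above defines a small message protocol over each
+ # env's input queue: ("get_specs", None), ("batch_size", None), ("reset", td),
+ # ("_reset", td), ("step", td), ("_step", td), ("step_and_maybe_reset", td) and
+ # ("shutdown", None). A minimal sketch of driving one worker by hand through the
+ # queues created in `_setup` (illustrative only; assumes a single batch-less
+ # Pendulum env):
+ #
+ #     >>> pool = AsyncEnvPool([partial(GymEnv, "Pendulum-v1")], backend="multiprocessing")
+ #     >>> pool.input_queue[0].put(("batch_size", None))
+ #     >>> pool.output_queue[0].get()
+ #     torch.Size([])
+ #     >>> pool.shutdown()
+
+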
+ class ThreadingAsyncEnvPool(AsyncEnvPool):
+     """An implementation of `AsyncEnvPool` using threading for parallel execution of environments.
+
+     This class manages a pool of environments, each running in its own thread, and
+     provides methods for asynchronous stepping and resetting of environments using
+     a thread pool executor.
+
+     .. note:: This class and its subclasses should work when nested within :class:`~torchrl.envs.TransformedEnv` and
+         batched environments, but users currently cannot use the async features of the base environment when
+         it is nested in these classes. One should prefer nesting transformed envs within an `AsyncEnvPool` instead.
+         If this is not possible, please raise an issue.
+
+     Methods:
+         _setup(): Initializes the thread pool and environment instances.
+         async_step_send(tensordict): Sends a step command to the environments.
+         async_step_recv(min_get): Receives the results of the step command.
+         async_reset_send(tensordict): Sends a reset command to the environments.
+         async_reset_recv(min_get): Receives the results of the reset command.
+         shutdown(): Shuts down the thread pool.
+     """
+
+     def _setup(self):
+         self._pool = ThreadPoolExecutor(max_workers=self.num_envs)
+         self.envs = [
+             env_factory(**create_env_kwargs)
+             if not isinstance(env_factory, EnvBase)
+             else env_factory
+             for env_factory, create_env_kwargs in zip(
+                 self.env_makers, self.create_env_kwargs
+             )
+         ]
+         self._reset_futures = []
+         self._private_reset_futures = []
+         self._step_futures = []
+         self._private_step_futures = []
+         self._step_and_maybe_reset_futures = []
+         self._current_step = 0
+         self._current_step_reset = 0
+         self._current_reset = 0
+
+         # Get the specs
+         specs = torch.stack([env.specs for env in self.envs])
+         return specs["output_spec"].clone(), specs["input_spec"].clone()
+
+     @property
+     def env_batch_sizes(self) -> list[torch.Size]:
+         return [env.batch_size for env in self.envs]
+
+     def _get_child_specs(self) -> list:
+         """Returns the specs from each child environment."""
+         return [env.specs for env in self.envs]
+
+     @classmethod
+     def _get_specs(cls, env: EnvBase):
+         return env.specs
+
+     @classmethod
+     def _step_func(cls, env_td: tuple[EnvBase, TensorDictBase, int]):
+         env, td, idx = env_td
+         return env.step(td).set(cls._env_idx_key, NonTensorData(idx))
+
+     @classmethod
+     def _private_step_func(cls, env_td: tuple[EnvBase, TensorDictBase, int]):
+         env, td, idx = env_td
+         return env._step(td).set(cls._env_idx_key, NonTensorData(idx))
+
+     @classmethod
+     def _reset_func(cls, env_td: tuple[EnvBase, TensorDictBase, int]):
+         env, td, idx = env_td
+         return env.reset(td).set(cls._env_idx_key, NonTensorData(idx))
+
+     @classmethod
+     def _private_reset_func(cls, env_td: tuple[EnvBase, TensorDictBase, int]):
+         env, td, idx = env_td
+         return env._reset(td).set(cls._env_idx_key, NonTensorData(idx))
+
+     @classmethod
+     def _step_and_maybe_reset_func(cls, env_td: tuple[EnvBase, TensorDictBase, int]):
+         env, td, idx = env_td
+         td, td_ = env.step_and_maybe_reset(td)
+         idx = NonTensorData(idx)
+         return td.set(cls._env_idx_key, idx), td_.set(cls._env_idx_key, idx)
+
+     def async_step_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         tds = tensordict.unbind(0)
+         envs = [self.envs[idx] for idx in env_idx]
+         futures = [
+             self._pool.submit(self._step_func, (env, td, idx))
+             for env, td, idx in zip(envs, tds, env_idx)
+         ]
+         self._step_futures.extend(futures)
+         self._current_step = self._current_step + len(futures)
+
+     def async_step_recv(self, min_get: int | None = None) -> TensorDictBase:
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_step:
+             raise RuntimeError(
+                 f"Cannot await {min_get} steps when only {self._current_step} are being stepped."
+             )
+         results = []
+         futures = self._step_futures
+         completed_futures = []
+         for future in as_completed(futures):
+             results.append(future.result())
+             completed_futures.append(future)
+             self._current_step = self._current_step - 1
+             if len(results) >= min_get and not any(
+                 f.done() for f in futures if f not in completed_futures
+             ):
+                 break
+         self._step_futures = [
+             f for f in self._step_futures if f not in completed_futures
+         ]
+         results, idx = self._sort_results(results)
+         self._busy.difference_update(idx)
+         return self._stack_func(results)
+
+     def _async_private_step_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         tds = tensordict.unbind(0)
+         envs = [self.envs[idx] for idx in env_idx]
+         futures = [
+             self._pool.submit(self._private_step_func, (env, td, idx))
+             for env, td, idx in zip(envs, tds, env_idx)
+         ]
+         self._private_step_futures.extend(futures)
+         self._current_step = self._current_step + len(futures)
+
+     def _async_private_step_recv(self, min_get: int | None = None) -> TensorDictBase:
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_step:
+             raise RuntimeError(
+                 f"Cannot await {min_get} steps when only {self._current_step} are being stepped."
+             )
+         results = []
+         futures = self._private_step_futures
+         completed_futures = []
+         for future in as_completed(futures):
+             results.append(future.result())
+             completed_futures.append(future)
+             self._current_step = self._current_step - 1
+             if len(results) >= min_get and not any(
+                 f.done() for f in futures if f not in completed_futures
+             ):
+                 break
+         self._private_step_futures = [
+             f for f in self._private_step_futures if f not in completed_futures
+         ]
+         results, idx = self._sort_results(results)
+         self._busy.difference_update(idx)
+         return self._stack_func(results)
+
+     def async_step_and_maybe_reset_send(
+         self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
+     ) -> None:
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         tds = tensordict.unbind(0)
+         envs = [self.envs[idx] for idx in env_idx]
+         futures = [
+             self._pool.submit(self._step_and_maybe_reset_func, (env, td, idx))
+             for env, td, idx in zip(envs, tds, env_idx)
+         ]
+         self._step_and_maybe_reset_futures.extend(futures)
+         self._current_step_reset = self._current_step_reset + len(futures)
+
+     def async_step_and_maybe_reset_recv(
+         self, min_get: int | None = None
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_step_reset:
+             raise RuntimeError(
+                 f"Cannot await {min_get} step_and_maybe_reset when only {self._current_step_reset} are being stepped."
+             )
+         results = []
+         futures = self._step_and_maybe_reset_futures
+         completed_futures = []
+         for future in as_completed(futures):
+             results.append(future.result())
+             completed_futures.append(future)
+             self._current_step_reset = self._current_step_reset - 1
+             if len(results) >= min_get and not any(
+                 f.done() for f in futures if f not in completed_futures
+             ):
+                 break
+         self._step_and_maybe_reset_futures = [
+             f for f in self._step_and_maybe_reset_futures if f not in completed_futures
+         ]
+         results, results_ = zip(*results)
+         results, results_, idx = self._sort_results(results, results_)
+         self._busy.difference_update(idx)
+         return self._stack_func(results), self._stack_func(results_)
+
+     def async_reset_send(
+         self,
+         tensordict: TensorDictBase | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> None:
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, True)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         tds = tensordict.unbind(0)
+         envs = [self.envs[idx] for idx in env_idx]
+         futures = [
+             self._pool.submit(self._reset_func, (env, td, idx))
+             for env, td, idx in zip(envs, tds, env_idx)
+         ]
+         self._current_reset = self._current_reset + len(futures)
+         self._reset_futures.extend(futures)
+
+     def async_reset_recv(self, min_get: int | None = None) -> TensorDictBase:
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_reset:
+             raise RuntimeError(
+                 f"Cannot await {min_get} resets when only {self._current_reset} are being reset."
+             )
+         results = []
+         futures = self._reset_futures
+         completed_futures = []
+         for future in as_completed(futures):
+             results.append(future.result())
+             completed_futures.append(future)
+             self._current_reset = self._current_reset - 1
+             if len(results) >= min_get and not any(
+                 f.done() for f in futures if f not in completed_futures
+             ):
+                 break
+         self._reset_futures = [
+             f for f in self._reset_futures if f not in completed_futures
+         ]
+         results, idx = self._sort_results(results)
+         self._busy.difference_update(idx)
+         return self._stack_func(results)
+
+     def _async_private_reset_send(
+         self,
+         tensordict: TensorDictBase | None = None,
+         env_index: int | list[int] | None = None,
+     ) -> None:
+         tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, True)
+
+         if self._busy.intersection(env_idx):
+             raise RuntimeError(
+                 f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."
+             )
+         self._busy.update(env_idx)
+
+         tds = tensordict.unbind(0)
+         envs = [self.envs[idx] for idx in env_idx]
+         futures = [
+             self._pool.submit(self._private_reset_func, (env, td, idx))
+             for env, td, idx in zip(envs, tds, env_idx)
+         ]
+         self._current_reset = self._current_reset + len(futures)
+         self._private_reset_futures.extend(futures)
+
+     def _async_private_reset_recv(self, min_get: int | None = None) -> TensorDictBase:
+         if min_get is None:
+             min_get = self.min_get
+         if min_get > self._current_reset:
+             raise RuntimeError(
+                 f"Cannot await {min_get} resets when only {self._current_reset} are being reset."
+             )
+         results = []
+         futures = self._private_reset_futures
+         completed_futures = []
+         for future in as_completed(futures):
+             results.append(future.result())
+             completed_futures.append(future)
+             self._current_reset = self._current_reset - 1
+             if len(results) >= min_get and not any(
+                 f.done() for f in futures if f not in completed_futures
+             ):
+                 break
+         self._private_reset_futures = [
+             f for f in self._private_reset_futures if f not in completed_futures
+         ]
+         results, idx = self._sort_results(results)
+         self._busy.difference_update(idx)
+         return self._stack_func(results)
+
+     def shutdown(self):
+         self._pool.shutdown()
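+
+
+ # End-to-end usage sketch for the threading backend, mirroring the `AsyncEnvPool`
+ # docstring (illustrative only):
+ #
+ #     >>> from functools import partial
+ #     >>> from torchrl.envs import AsyncEnvPool, GymEnv
+ #     >>> pool = AsyncEnvPool(
+ #     ...     [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "Pendulum-v1")],
+ #     ...     backend="threading",
+ #     ...     stack="lazy",
+ #     ... )
+ #     >>> td = pool.reset()
+ #     >>> td = pool.rand_step(td)
+ #     >>> pool.close()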