torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from functools import partial
10
+ from typing import Any
11
+
12
+ from omegaconf import MISSING
13
+
14
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
15
+ from torchrl.trainers.algorithms.configs.envs import EnvConfig
16
+
17
+
18
@dataclass
class BaseCollectorConfig(ConfigBase):
    """Common ancestor of every data-collector configuration class."""
21
+
22
+
23
@dataclass
class CollectorConfig(BaseCollectorConfig):
    """Configures a synchronous data collector.

    ``_target_`` points at ``torchrl.collectors.Collector``; the fields below
    are forwarded to it at instantiation time. Only ``create_env_fn`` is
    mandatory — every other field keeps the collector's default.
    """

    # --- environment / policy -------------------------------------------
    create_env_fn: ConfigBase = MISSING
    policy: Any = None
    policy_factory: Any = None
    # --- frame budget ----------------------------------------------------
    frames_per_batch: int | None = None
    total_frames: int = -1
    init_random_frames: int | None = 0
    # --- device placement ------------------------------------------------
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # --- rollout behavior ------------------------------------------------
    create_env_kwargs: dict | None = None
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: Any = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    return_same_td: bool = False
    interruptor: Any = None
    set_truncated: bool = False
    use_buffers: bool = False
    replay_buffer: Any = None
    extend_buffer: bool = False
    # --- policy handling / compilation -----------------------------------
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # --- weight synchronization ------------------------------------------
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # --- Hydra instantiation hooks ---------------------------------------
    _target_: str = "torchrl.collectors.Collector"
    _partial_: bool = False

    def __post_init__(self):
        # The env factory must stay a factory: the collector instantiates it.
        self.create_env_fn._partial_ = True
        # Optional sub-configs are likewise deferred when provided.
        for attr_name in ("policy_factory", "weight_updater"):
            sub_cfg = getattr(self, attr_name)
            if sub_cfg is not None:
                sub_cfg._partial_ = True
66
+
67
+
68
# Legacy alias: ``SyncDataCollectorConfig`` is the historical name of
# ``CollectorConfig``, kept for backward compatibility.
SyncDataCollectorConfig = CollectorConfig
70
+
71
+
72
@dataclass
class AsyncCollectorConfig(BaseCollectorConfig):
    """Configures an asynchronous data collector.

    ``_target_`` points at ``torchrl.collectors.AsyncCollector``. Unlike
    :class:`CollectorConfig`, ``create_env_fn`` defaults to a partial
    ``EnvConfig`` instead of being mandatory.
    """

    # --- environment / policy -------------------------------------------
    create_env_fn: ConfigBase = field(
        default_factory=partial(EnvConfig, _partial_=True)
    )
    policy: Any = None
    policy_factory: Any = None
    # --- frame budget ----------------------------------------------------
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # --- device placement ------------------------------------------------
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # --- rollout behavior ------------------------------------------------
    create_env_kwargs: dict | None = None
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # --- policy handling / compilation -----------------------------------
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # --- weight synchronization ------------------------------------------
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # --- Hydra instantiation hooks ---------------------------------------
    _target_: str = "torchrl.collectors.AsyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # The env factory must stay a factory: the collector instantiates it.
        self.create_env_fn._partial_ = True
        # Optional sub-configs are likewise deferred when provided.
        for attr_name in ("policy_factory", "weight_updater"):
            sub_cfg = getattr(self, attr_name)
            if sub_cfg is not None:
                sub_cfg._partial_ = True
115
+
116
+
117
# Legacy alias: ``AsyncDataCollectorConfig`` is the historical name of
# ``AsyncCollectorConfig``, kept for backward compatibility.
AsyncDataCollectorConfig = AsyncCollectorConfig
119
+
120
+
121
@dataclass
class MultiSyncCollectorConfig(BaseCollectorConfig):
    """Configures a multi-process synchronous collector.

    ``_target_`` points at ``torchrl.collectors.MultiSyncCollector``.
    ``create_env_fn`` is iterated in ``__post_init__`` and is therefore
    expected to be a collection of environment configs (presumably one per
    worker — confirm against the collector's API).
    """

    # --- environment / workers -------------------------------------------
    create_env_fn: Any = MISSING
    num_workers: int | None = None
    # --- policy -----------------------------------------------------------
    policy: Any = None
    policy_factory: Any = None
    # --- frame budget ------------------------------------------------------
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # --- device placement --------------------------------------------------
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # --- rollout behavior --------------------------------------------------
    create_env_kwargs: dict | None = None
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # --- policy handling / compilation --------------------------------------
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # --- weight synchronization ----------------------------------------------
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # --- Hydra instantiation hooks --------------------------------------------
    _target_: str = "torchrl.collectors.MultiSyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # Every per-worker env config must stay a factory.
        for sub_env_cfg in self.create_env_fn:
            sub_env_cfg._partial_ = True
        # Optional sub-configs are likewise deferred when provided.
        for attr_name in ("policy_factory", "weight_updater"):
            sub_cfg = getattr(self, attr_name)
            if sub_cfg is not None:
                sub_cfg._partial_ = True
164
+
165
+
166
# Legacy alias: historical name of ``MultiSyncCollectorConfig``.
# NOTE(review): the original line was a no-op self-assignment
# (``MultiSyncCollectorConfig = MultiSyncCollectorConfig``). The sibling
# aliases bind a ``*DataCollectorConfig`` name, so bind the matching
# legacy name here.
MultiSyncDataCollectorConfig = MultiSyncCollectorConfig
168
+
169
+
170
@dataclass
class MultiAsyncCollectorConfig(BaseCollectorConfig):
    """Configures a multi-process asynchronous collector.

    ``_target_`` points at ``torchrl.collectors.MultiAsyncCollector``.
    ``create_env_fn`` is iterated in ``__post_init__`` and is therefore
    expected to be a collection of environment configs (presumably one per
    worker — confirm against the collector's API).
    """

    # --- environment / workers -------------------------------------------
    create_env_fn: Any = MISSING
    num_workers: int | None = None
    # --- policy -----------------------------------------------------------
    policy: Any = None
    policy_factory: Any = None
    # --- frame budget ------------------------------------------------------
    frames_per_batch: int | None = None
    init_random_frames: int | None = 0
    total_frames: int = -1
    # --- device placement --------------------------------------------------
    device: str | None = None
    storing_device: str | None = None
    policy_device: str | None = None
    env_device: str | None = None
    # --- rollout behavior --------------------------------------------------
    create_env_kwargs: dict | None = None
    max_frames_per_traj: int | None = None
    reset_at_each_iter: bool = False
    postproc: ConfigBase | None = None
    split_trajs: bool = False
    exploration_type: str = "RANDOM"
    set_truncated: bool = False
    use_buffers: bool = False
    replay_buffer: ConfigBase | None = None
    extend_buffer: bool = False
    # --- policy handling / compilation --------------------------------------
    trust_policy: bool = True
    compile_policy: Any = None
    cudagraph_policy: Any = None
    no_cuda_sync: bool = False
    # --- weight synchronization ----------------------------------------------
    weight_updater: Any = None
    weight_sync_schemes: Any = None
    track_policy_version: bool = False
    local_init_rb: bool = False
    # --- Hydra instantiation hooks --------------------------------------------
    _target_: str = "torchrl.collectors.MultiAsyncCollector"
    _partial_: bool = False

    def __post_init__(self):
        # Every per-worker env config must stay a factory.
        for sub_env_cfg in self.create_env_fn:
            sub_env_cfg._partial_ = True
        # Optional sub-configs are likewise deferred when provided.
        for attr_name in ("policy_factory", "weight_updater"):
            sub_cfg = getattr(self, attr_name)
            if sub_cfg is not None:
                sub_cfg._partial_ = True
213
+
214
+
215
# Legacy alias: historical name of ``MultiAsyncCollectorConfig``.
# NOTE(review): the original line was a no-op self-assignment
# (``MultiAsyncCollectorConfig = MultiAsyncCollectorConfig``). The sibling
# aliases bind a ``*DataCollectorConfig`` name, so bind the matching
# legacy name here.
MultiAsyncDataCollectorConfig = MultiAsyncCollectorConfig
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass
10
+
11
+ from omegaconf import DictConfig
12
+
13
+
14
@dataclass
class ConfigBase(ABC):
    """Root of the configuration-class hierarchy.

    Every configuration dataclass derives from this class; the only
    contract it imposes is a ``__post_init__`` hook, which the dataclass
    machinery invokes after field initialization.
    """

    @abstractmethod
    def __post_init__(self) -> None:
        """Hook run after dataclass field initialization."""
25
+
26
+
27
@dataclass
class Config:
    """A schema-less configuration object.

    Accepts arbitrary keyword arguments and stores them in an
    ``omegaconf.DictConfig``; attribute reads and writes are proxied to
    that underlying container.
    """

    def __init__(self, **kwargs):
        # The only attribute kept on the instance itself; everything else
        # lives inside the DictConfig.
        self._config = DictConfig(kwargs)

    def __getattr__(self, name):
        # Guard against infinite recursion when ``_config`` is not set yet
        # (e.g. instances produced via ``__new__``/unpickling): without it,
        # looking up ``self._config`` re-enters ``__getattr__`` forever and
        # raises RecursionError instead of the expected AttributeError.
        if name == "_config":
            raise AttributeError(name)
        return getattr(self._config, name)

    def __setattr__(self, name, value):
        if name == "_config":
            # Store the container itself normally.
            super().__setattr__(name, value)
        else:
            # Everything else is written into the container.
            setattr(self._config, name, value)
@@ -0,0 +1,308 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Any
10
+
11
+ from omegaconf import MISSING
12
+
13
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
14
+
15
+
16
+ @dataclass
17
+ class WriterConfig(ConfigBase):
18
+ """Base configuration class for replay buffer writers."""
19
+
20
+ _target_: str = "torchrl.data.replay_buffers.Writer"
21
+
22
+ def __post_init__(self) -> None:
23
+ """Post-initialization hook for writer configurations."""
24
+
25
+
26
+ @dataclass
27
+ class RoundRobinWriterConfig(WriterConfig):
28
+ """Configuration for round-robin writer that distributes data across multiple storages."""
29
+
30
+ _target_: str = "torchrl.data.replay_buffers.RoundRobinWriter"
31
+ compilable: bool = False
32
+
33
+ def __post_init__(self) -> None:
34
+ """Post-initialization hook for round-robin writer configurations."""
35
+ super().__post_init__()
36
+
37
+
38
+ @dataclass
39
+ class SamplerConfig(ConfigBase):
40
+ """Base configuration class for replay buffer samplers."""
41
+
42
+ _target_: str = "torchrl.data.replay_buffers.Sampler"
43
+
44
+ def __post_init__(self) -> None:
45
+ """Post-initialization hook for sampler configurations."""
46
+
47
+
48
+ @dataclass
49
+ class RandomSamplerConfig(SamplerConfig):
50
+ """Configuration for random sampling from replay buffer."""
51
+
52
+ _target_: str = "torchrl.data.replay_buffers.RandomSampler"
53
+
54
+ def __post_init__(self) -> None:
55
+ """Post-initialization hook for random sampler configurations."""
56
+ super().__post_init__()
57
+
58
+
59
+ @dataclass
60
+ class WriterEnsembleConfig(WriterConfig):
61
+ """Configuration for ensemble writer that combines multiple writers."""
62
+
63
+ _target_: str = "torchrl.data.replay_buffers.WriterEnsemble"
64
+ writers: list[Any] = field(default_factory=list)
65
+ p: Any = None
66
+
67
+
68
+ @dataclass
69
+ class TensorDictMaxValueWriterConfig(WriterConfig):
70
+ """Configuration for TensorDict max value writer."""
71
+
72
+ _target_: str = "torchrl.data.replay_buffers.TensorDictMaxValueWriter"
73
+ rank_key: Any = None
74
+ reduction: str = "sum"
75
+
76
+
77
+ @dataclass
78
+ class TensorDictRoundRobinWriterConfig(WriterConfig):
79
+ """Configuration for TensorDict round-robin writer."""
80
+
81
+ _target_: str = "torchrl.data.replay_buffers.TensorDictRoundRobinWriter"
82
+ compilable: bool = False
83
+
84
+
85
+ @dataclass
86
+ class ImmutableDatasetWriterConfig(WriterConfig):
87
+ """Configuration for immutable dataset writer."""
88
+
89
+ _target_: str = "torchrl.data.replay_buffers.ImmutableDatasetWriter"
90
+
91
+
92
+ @dataclass
93
+ class SamplerEnsembleConfig(SamplerConfig):
94
+ """Configuration for ensemble sampler that combines multiple samplers."""
95
+
96
+ _target_: str = "torchrl.data.replay_buffers.SamplerEnsemble"
97
+ samplers: list[Any] = field(default_factory=list)
98
+ p: Any = None
99
+
100
+
101
+ @dataclass
102
+ class PrioritizedSliceSamplerConfig(SamplerConfig):
103
+ """Configuration for prioritized slice sampling from replay buffer."""
104
+
105
+ num_slices: int | None = None
106
+ slice_len: int | None = None
107
+ end_key: Any = None
108
+ traj_key: Any = None
109
+ ends: Any = None
110
+ trajectories: Any = None
111
+ cache_values: bool = False
112
+ truncated_key: Any = ("next", "truncated")
113
+ strict_length: bool = True
114
+ compile: Any = False
115
+ span: Any = False
116
+ use_gpu: Any = False
117
+ max_capacity: int | None = None
118
+ alpha: float | None = None
119
+ beta: float | None = None
120
+ eps: float | None = None
121
+ reduction: str | None = None
122
+ _target_: str = "torchrl.data.replay_buffers.PrioritizedSliceSampler"
123
+
124
+
125
+ @dataclass
126
+ class SliceSamplerWithoutReplacementConfig(SamplerConfig):
127
+ """Configuration for slice sampling without replacement."""
128
+
129
+ _target_: str = "torchrl.data.replay_buffers.SliceSamplerWithoutReplacement"
130
+ num_slices: int | None = None
131
+ slice_len: int | None = None
132
+ end_key: Any = None
133
+ traj_key: Any = None
134
+ ends: Any = None
135
+ trajectories: Any = None
136
+ cache_values: bool = False
137
+ truncated_key: Any = ("next", "truncated")
138
+ strict_length: bool = True
139
+ compile: Any = False
140
+ span: Any = False
141
+ use_gpu: Any = False
142
+
143
+
144
+ @dataclass
145
+ class SliceSamplerConfig(SamplerConfig):
146
+ """Configuration for slice sampling from replay buffer."""
147
+
148
+ _target_: str = "torchrl.data.replay_buffers.SliceSampler"
149
+ num_slices: int | None = None
150
+ slice_len: int | None = None
151
+ end_key: Any = None
152
+ traj_key: Any = None
153
+ ends: Any = None
154
+ trajectories: Any = None
155
+ cache_values: bool = False
156
+ truncated_key: Any = ("next", "truncated")
157
+ strict_length: bool = True
158
+ compile: Any = False
159
+ span: Any = False
160
+ use_gpu: Any = False
161
+
162
+
163
+ @dataclass
164
+ class PrioritizedSamplerConfig(SamplerConfig):
165
+ """Configuration for prioritized sampling from replay buffer."""
166
+
167
+ max_capacity: int | None = None
168
+ alpha: float | None = None
169
+ beta: float | None = None
170
+ eps: float | None = None
171
+ reduction: str | None = None
172
+ _target_: str = "torchrl.data.replay_buffers.PrioritizedSampler"
173
+
174
+
175
+ @dataclass
176
+ class SamplerWithoutReplacementConfig(SamplerConfig):
177
+ """Configuration for sampling without replacement."""
178
+
179
+ _target_: str = "torchrl.data.replay_buffers.SamplerWithoutReplacement"
180
+ drop_last: bool = False
181
+ shuffle: bool = True
182
+
183
+
184
+ @dataclass
185
+ class StorageConfig(ConfigBase):
186
+ """Base configuration class for replay buffer storage."""
187
+
188
+ _partial_: bool = False
189
+ _target_: str = "torchrl.data.replay_buffers.Storage"
190
+
191
+ def __post_init__(self) -> None:
192
+ """Post-initialization hook for storage configurations."""
193
+
194
+
195
+ @dataclass
196
+ class TensorStorageConfig(StorageConfig):
197
+ """Configuration for tensor-based storage in replay buffer."""
198
+
199
+ _target_: str = "torchrl.data.replay_buffers.TensorStorage"
200
+ max_size: int | None = None
201
+ storage: Any = None
202
+ device: Any = None
203
+ ndim: int | None = None
204
+ compilable: bool = False
205
+
206
+ def __post_init__(self) -> None:
207
+ """Post-initialization hook for tensor storage configurations."""
208
+ super().__post_init__()
209
+
210
+
211
+ @dataclass
212
+ class ListStorageConfig(StorageConfig):
213
+ """Configuration for list-based storage in replay buffer."""
214
+
215
+ _target_: str = "torchrl.data.replay_buffers.ListStorage"
216
+ max_size: int | None = None
217
+ compilable: bool = False
218
+
219
+
220
+ @dataclass
221
+ class StorageEnsembleWriterConfig(StorageConfig):
222
+ """Configuration for storage ensemble writer."""
223
+
224
+ _target_: str = "torchrl.data.replay_buffers.StorageEnsembleWriter"
225
+ writers: list[Any] = MISSING
226
+ transforms: list[Any] = MISSING
227
+
228
+
229
+ @dataclass
230
+ class LazyStackStorageConfig(StorageConfig):
231
+ """Configuration for lazy stack storage."""
232
+
233
+ _target_: str = "torchrl.data.replay_buffers.LazyStackStorage"
234
+ max_size: int | None = None
235
+ compilable: bool = False
236
+ stack_dim: int = 0
237
+
238
+
239
+ @dataclass
240
+ class StorageEnsembleConfig(StorageConfig):
241
+ """Configuration for storage ensemble."""
242
+
243
+ _target_: str = "torchrl.data.replay_buffers.StorageEnsemble"
244
+ storages: list[Any] = MISSING
245
+ transforms: list[Any] = MISSING
246
+
247
+
248
+ @dataclass
249
+ class LazyMemmapStorageConfig(StorageConfig):
250
+ """Configuration for lazy memory-mapped storage."""
251
+
252
+ _target_: str = "torchrl.data.replay_buffers.LazyMemmapStorage"
253
+ max_size: int | None = None
254
+ device: Any = None
255
+ ndim: int = 1
256
+ compilable: bool = False
257
+ shared_init: bool = False
258
+
259
+
260
+ @dataclass
261
+ class LazyTensorStorageConfig(StorageConfig):
262
+ """Configuration for lazy tensor storage."""
263
+
264
+ _target_: str = "torchrl.data.replay_buffers.LazyTensorStorage"
265
+ max_size: int | None = None
266
+ device: Any = None
267
+ ndim: int = 1
268
+ compilable: bool = False
269
+ shared_init: bool = False
270
+
271
+
272
+ @dataclass
273
+ class ReplayBufferBaseConfig(ConfigBase):
274
+ """Base configuration class for replay buffers."""
275
+
276
+ _partial_: bool = False
277
+
278
+ def __post_init__(self) -> None:
279
+ """Post-initialization hook for replay buffer configurations."""
280
+
281
+
282
+ @dataclass
283
+ class TensorDictReplayBufferConfig(ReplayBufferBaseConfig):
284
+ """Configuration for TensorDict-based replay buffer."""
285
+
286
+ _target_: str = "torchrl.data.replay_buffers.TensorDictReplayBuffer"
287
+ sampler: Any = None
288
+ storage: Any = None
289
+ writer: Any = None
290
+ transform: Any = None
291
+ batch_size: int | None = None
292
+
293
+ def __post_init__(self) -> None:
294
+ """Post-initialization hook for TensorDict replay buffer configurations."""
295
+ super().__post_init__()
296
+
297
+
298
+ @dataclass
299
+ class ReplayBufferConfig(ReplayBufferBaseConfig):
300
+ """Configuration for generic replay buffer."""
301
+
302
+ _target_: str = "torchrl.data.replay_buffers.ReplayBuffer"
303
+ sampler: Any = None
304
+ storage: Any = None
305
+ writer: Any = None
306
+ transform: Any = None
307
+ batch_size: int | None = None
308
+ shared: bool = False
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Any
10
+
11
+ from omegaconf import MISSING
12
+
13
+ from torchrl.envs.common import EnvBase
14
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
15
+
16
+
17
+ @dataclass
18
+ class EnvConfig(ConfigBase):
19
+ """Base configuration class for environments."""
20
+
21
+ _partial_: bool = False
22
+
23
+ def __post_init__(self) -> None:
24
+ """Post-initialization hook for environment configurations."""
25
+ self._partial_ = False
26
+
27
+
28
+ @dataclass
29
+ class BatchedEnvConfig(EnvConfig):
30
+ """Configuration for batched environments."""
31
+
32
+ create_env_fn: Any = MISSING
33
+ num_workers: int = 1
34
+ create_env_kwargs: dict = field(default_factory=dict)
35
+ batched_env_type: str = "parallel"
36
+ device: str | None = None
37
+ # batched_env_type: Literal["parallel", "serial", "async"] = "parallel"
38
+ _target_: str = "torchrl.trainers.algorithms.configs.envs.make_batched_env"
39
+
40
+ def __post_init__(self) -> None:
41
+ """Post-initialization hook for batched environment configurations."""
42
+ super().__post_init__()
43
+ if hasattr(self.create_env_fn, "_partial_"):
44
+ self.create_env_fn._partial_ = True
45
+
46
+
47
+ @dataclass
48
+ class TransformedEnvConfig(EnvConfig):
49
+ """Configuration for transformed environments."""
50
+
51
+ base_env: Any = MISSING
52
+ transform: Any = None
53
+ cache_specs: bool = True
54
+ auto_unwrap: bool | None = None
55
+ _target_: str = "torchrl.envs.TransformedEnv"
56
+
57
+
58
+ def make_batched_env(
59
+ create_env_fn, num_workers, batched_env_type="parallel", device=None, **kwargs
60
+ ):
61
+ """Create a batched environment.
62
+
63
+ Args:
64
+ create_env_fn: Function to create individual environments or environment instance.
65
+ num_workers: Number of worker environments.
66
+ batched_env_type: Type of batched environment (parallel, serial, async).
67
+ device: Device to place the batched environment on.
68
+ **kwargs: Additional keyword arguments.
69
+
70
+ Returns:
71
+ The created batched environment instance.
72
+ """
73
+ from torchrl.envs import AsyncEnvPool, ParallelEnv, SerialEnv
74
+
75
+ if create_env_fn is None:
76
+ raise ValueError("create_env_fn must be provided")
77
+
78
+ if num_workers is None:
79
+ raise ValueError("num_workers must be provided")
80
+
81
+ # If create_env_fn is a config object, create a lambda that instantiates it each time
82
+ if isinstance(create_env_fn, EnvBase):
83
+ # Already an instance (either instantiated config or actual env), wrap in lambda
84
+ env_instance = create_env_fn
85
+
86
+ def env_fn(env_instance=env_instance):
87
+ return env_instance
88
+
89
+ else:
90
+ env_fn = create_env_fn
91
+ assert callable(env_fn), env_fn
92
+
93
+ # Add device to kwargs if provided
94
+ if device is not None:
95
+ kwargs["device"] = device
96
+
97
+ if batched_env_type == "parallel":
98
+ return ParallelEnv(num_workers, env_fn, **kwargs)
99
+ elif batched_env_type == "serial":
100
+ return SerialEnv(num_workers, env_fn, **kwargs)
101
+ elif batched_env_type == "async":
102
+ return AsyncEnvPool([env_fn] * num_workers, **kwargs)
103
+ else:
104
+ raise ValueError(f"Unknown batched_env_type: {batched_env_type}")