torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ from tensordict import TensorDict, TensorDictBase
14
+ from torchrl._utils import logger as torchrl_logger
15
+ from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
16
+ from torchrl.envs.common import _EnvWrapper
17
+ from torchrl.envs.utils import _classproperty
18
+
19
+ _has_envpool = importlib.util.find_spec("envpool") is not None
20
+
21
+
22
class MultiThreadedEnvWrapper(_EnvWrapper):
    """Wrapper for envpool-based multithreaded environments.

    GitHub: https://github.com/sail-sg/envpool

    Paper: https://arxiv.org/abs/2206.10558

    EnvPool environments auto-reset internally when episodes end. This wrapper
    handles that behavior by caching the auto-reset observations and returning
    them appropriately in step_and_maybe_reset.

    Args:
        env (envpool.python.envpool.EnvPoolMixin): the envpool to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
            for these versions), the environment checker won't be run.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        batch_size: The number of envs run simultaneously.

    Examples:
        >>> import envpool
        >>> from torchrl.envs import MultiThreadedEnvWrapper
        >>> env_base = envpool.make(
        ...     task_id="Pong-v5", env_type="gym", num_envs=4, gym_reset_return_info=True
        ... )
        >>> env = MultiThreadedEnvWrapper(env_base)
        >>> env.reset()
        >>> env.rand_step()

    """

    _verbose: bool = False

    @_classproperty
    def lib(cls):
        import envpool

        return envpool

    def __init__(
        self,
        env: envpool.python.envpool.EnvPoolMixin | None = None,  # noqa: F821
        **kwargs,
    ):
        if not _has_envpool:
            raise ImportError(
                "envpool python package or one of its dependencies (gym, treevalue) were not found. Please install these dependencies."
            )
        if env is not None:
            kwargs["env"] = env
            self.num_workers = env.config["num_envs"]
            # For synchronous mode batch size is equal to the number of workers
            self.batch_size = torch.Size([self.num_workers])
        super().__init__(**kwargs)

        # Buffer to keep the latest observation for each worker.
        # It's a TensorDict when the observation consists of several variables,
        # e.g. "position" and "velocity".
        # NOTE: annotation fixed from ``torch.tensor`` (the factory function)
        # to ``torch.Tensor`` (the type).
        self.obs: torch.Tensor | TensorDict = self.observation_spec.zero()

    def _check_kwargs(self, kwargs: dict):
        """Validate that an ``env`` kwarg is present and is an EnvPool env."""
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        import envpool

        if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)):
            raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.")

    def _build_env(self, env: envpool.python.envpool.EnvPoolMixin):  # noqa: F821
        # The envpool env is already built by the caller; just keep it.
        return env

    def _make_specs(
        self, env: envpool.python.envpool.EnvPoolMixin  # noqa: F821
    ) -> None:  # noqa: F821
        from torchrl.envs.libs.gym import set_gym_backend

        # EnvPool exposes gym-style spaces; force the "gym" backend while
        # converting them to TorchRL specs.
        with set_gym_backend("gym"):
            self.action_spec = self._get_action_spec()
            output_spec = self._get_output_spec()
            self.observation_spec = output_spec["full_observation_spec"]
            self.reward_spec = output_spec["full_reward_spec"]
            self.done_spec = output_spec["full_done_spec"]

    def _init_env(self) -> int | None:
        # Nothing to do: envpool environments are ready upon construction.
        pass

    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Reset all workers, or only those flagged in ``"_reset"``."""
        if tensordict is not None:
            reset_workers = tensordict.get("_reset", None)
        else:
            reset_workers = None
        if reset_workers is not None:
            # Partial reset: envpool takes the integer indices of the envs to reset.
            reset_data = self._env.reset(np.where(reset_workers.cpu().numpy())[0])
        else:
            reset_data = self._env.reset()
        tensordict_out = self._transform_reset_output(reset_data, reset_workers)
        self.is_closed = False
        return tensordict_out

    @torch.no_grad()
    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        action = tensordict.get(self.action_key)
        # Action needs to be moved to CPU and converted to numpy before being passed to envpool
        action = action.to(torch.device("cpu"))
        step_output = self._env.step(action.numpy())
        tensordict_out = self._transform_step_output(step_output)
        return tensordict_out

    def step_and_maybe_reset(
        self, tensordict: TensorDictBase
    ) -> tuple[TensorDictBase, TensorDictBase]:
        """Runs a step and handles envpool's internal auto-reset.

        EnvPool auto-resets internally when episodes end. When done=True:
        - The observation returned is the final observation of the ending episode
        - The NEXT call to step() returns the first observation of a new episode

        This method handles this by skipping explicit reset() calls for done
        environments. EnvPool maintains its own internal state, so the next
        step() will automatically return the reset observation.

        Note: The observation in tensordict_ for done envs will be the final
        observation (not the reset observation). This is acceptable because
        envpool ignores the input observation and uses its internal state.
        """
        # Perform the step
        tensordict = self.step(tensordict)

        # Move data from "next" to root for the next iteration
        tensordict_ = self._step_mdp(tensordict)

        # EnvPool auto-resets internally, so we skip calling reset().
        # However, we need to clear the done flags in tensordict_ since envpool
        # has already reset those environments. The next step() will return
        # the reset observations automatically.
        for key in self.done_keys:
            if key in tensordict_.keys(True):
                tensordict_.set(key, torch.zeros_like(tensordict_.get(key)))

        return tensordict, tensordict_

    def _get_action_spec(self) -> TensorSpec:
        # local import to avoid importing gym in the script
        from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

        # Envpool provides Gym-compatible specs as env.spec.action_space and
        # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones.

        # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
        action_spec = _gym_to_torchrl_spec_transform(
            self._env.spec.action_space,
            device=self.device,
            categorical_action_encoding=True,
        )
        action_spec = self._add_shape_to_spec(action_spec)
        return action_spec

    def _get_output_spec(self) -> TensorSpec:
        """Assemble observation/reward/done specs into one Composite."""
        return Composite(
            full_observation_spec=self._get_observation_spec(),
            full_reward_spec=self._get_reward_spec(),
            full_done_spec=self._get_done_spec(),
            shape=(self.num_workers,),
            device=self.device,
        )

    def _get_observation_spec(self) -> TensorSpec:
        # local import to avoid importing gym in the script
        from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

        # Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
        observation_spec = _gym_to_torchrl_spec_transform(
            self._env.spec.observation_space,
            device=self.device,
            categorical_action_encoding=True,
        )
        observation_spec = self._add_shape_to_spec(observation_spec)
        if isinstance(observation_spec, Composite):
            return observation_spec
        # Single-array observations are wrapped under an "observation" key.
        return Composite(
            observation=observation_spec,
            shape=(self.num_workers,),
            device=self.device,
        )

    def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
        # Prepend the worker dimension to a per-env spec.
        return spec.expand((self.num_workers, *spec.shape))

    def _get_reward_spec(self) -> TensorSpec:
        return Unbounded(
            device=self.device,
            shape=self.batch_size,
        )

    def _get_done_spec(self) -> TensorSpec:
        # Boolean done/terminated/truncated flags, one per worker.
        spec = Categorical(
            2,
            device=self.device,
            shape=self.batch_size,
            dtype=torch.bool,
        )
        return Composite(
            done=spec,
            truncated=spec.clone(),
            terminated=spec.clone(),
            shape=self.batch_size,
            device=self.device,
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_workers={self.num_workers}, device={self.device})"

    def _transform_reset_output(
        self,
        envpool_output: tuple[treevalue.TreeValue | np.ndarray, Any],  # noqa: F821
        reset_workers: torch.Tensor | None,
    ):
        """Process output of envpool env.reset."""
        import treevalue

        observation, _ = envpool_output
        if reset_workers is not None:
            # Only specified workers were reset - need to set observation buffer values only for them
            if isinstance(observation, treevalue.TreeValue):
                # If observation contain several fields, it will be returned as treevalue.TreeValue.
                # Convert to treevalue.FastTreeValue to allow indexing
                observation = treevalue.FastTreeValue(observation)
            self.obs[reset_workers] = self._treevalue_or_numpy_to_tensor_or_dict(
                observation
            )
        else:
            # All workers were reset - rewrite the whole observation buffer
            self.obs = TensorDict(
                self._treevalue_or_numpy_to_tensor_or_dict(observation),
                self.batch_size,
                device=self.device,
            )

        obs = self.obs.clone(False)
        # Fresh episodes start with all done flags cleared.
        obs.update(self.full_done_spec.zero())
        return obs

    def _transform_step_output(
        self, envpool_output: tuple[Any, ...]
    ) -> TensorDict:
        """Process output of envpool env.step.

        Supports both the 4-tuple (obs, reward, done, info) gym API and the
        5-tuple (obs, reward, terminated, truncated, info) gymnasium API.
        """
        out = envpool_output
        if len(out) == 4:
            obs, reward, done, info = out
            terminated = done
            # ``done * 0`` yields an all-False array with the right shape/dtype
            # when the info dict carries no truncation signal.
            truncated = info.get("TimeLimit.truncated", done * 0)
        elif len(out) == 5:
            obs, reward, terminated, truncated, info = out
            done = terminated | truncated
        else:
            raise TypeError(
                f"The output of step had {len(out)} elements, but only 4 or 5 are supported."
            )
        obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
        reward_and_done = {self.reward_key: torch.as_tensor(reward)}
        reward_and_done["done"] = done
        reward_and_done["terminated"] = terminated
        reward_and_done["truncated"] = truncated
        obs.update(reward_and_done)
        # Cache the latest observation so partial resets can patch it in place.
        self.obs = tensordict_out = TensorDict(
            obs,
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _treevalue_or_numpy_to_tensor_or_dict(
        self, x: treevalue.TreeValue | np.ndarray  # noqa: F821
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Converts observation returned by EnvPool.

        EnvPool step and reset return observation as a numpy array or a TreeValue of numpy arrays, which we convert
        to a tensor or a dictionary of tensors. Currently only supports depth 1 trees, but can easily be extended to
        arbitrary depth if necessary.
        """
        import treevalue

        if isinstance(x, treevalue.TreeValue):
            ret = self._treevalue_to_dict(x)
        elif not isinstance(x, dict):
            ret = {"observation": torch.as_tensor(x)}
        else:
            ret = x
        return ret

    def _treevalue_to_dict(
        self, tv: treevalue.TreeValue  # noqa: F821
    ) -> dict[str, Any]:
        """Converts TreeValue to a dictionary.

        Currently only supports depth 1 trees, but can easily be extended to arbitrary depth if necessary.
        """
        import treevalue

        # treevalue.flatten yields (path, value) pairs; at depth 1 the path is
        # a 1-tuple, hence k[0].
        return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)}

    def _set_seed(self, seed: int | None) -> None:
        # EnvPool cannot reseed a live environment: the seed must be given to
        # the constructor, so we only log and ignore the request here.
        if seed is not None:
            torchrl_logger.info(
                "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an "
                "existing environment is not supported by envpool. Please create "
                "a new environment, passing the seed to the constructor."
            )
345
+
346
class MultiThreadedEnv(MultiThreadedEnvWrapper):
    """Multithreaded execution of environments based on EnvPool.

    GitHub: https://github.com/sail-sg/envpool

    Paper: https://arxiv.org/abs/2206.10558

    An alternative to ParallelEnv based on multithreading. It's faster, as it doesn't require new process spawning, but
    less flexible, as it only supports environments implemented in EnvPool library.
    Currently, only supports synchronous execution mode, when the batch size is equal to the number of workers, see
    https://envpool.readthedocs.io/en/latest/content/python_interface.html#batch-size.

    Args:
        num_workers (int): The number of envs to run simultaneously. Will be
            identical to the content of `~.batch_size`.
        env_name (str): name of the environment to build.

    Keyword Args:
        create_env_kwargs (Dict[str, Any], optional): kwargs to be passed to envpool
            environment constructor.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
            for these versions), the environment checker won't be run.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Examples:
        >>> env = MultiThreadedEnv(num_workers=3, env_name="Pendulum-v1")
        >>> env.reset()
        >>> env.rand_step()
        >>> env.rollout(5)
        >>> env.close()

    """

    def __init__(
        self,
        num_workers: int,
        env_name: str,
        *,
        create_env_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        self.env_name = env_name.replace("ALE/", "")  # Naming convention of EnvPool
        self.num_workers = num_workers
        self.batch_size = torch.Size([num_workers])
        self.create_env_kwargs = create_env_kwargs or {}

        kwargs["num_workers"] = num_workers
        kwargs["env_name"] = self.env_name
        # Forward the normalized (never ``None``) dict so that the value stored
        # on the instance and the one the parent wrapper receives cannot
        # diverge (previously a raw ``None`` could be forwarded).
        kwargs["create_env_kwargs"] = self.create_env_kwargs
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        num_workers: int,
        create_env_kwargs: dict[str, Any] | None,
    ) -> Any:
        """Instantiate the underlying envpool environment batch.

        Args:
            env_name: EnvPool task id.
            num_workers: number of parallel envs (EnvPool ``num_envs``).
            create_env_kwargs: extra kwargs forwarded to ``envpool.make``;
                may be ``None``. NOTE: the dict is mutated in place when
                ``max_num_players`` is missing.
        """
        import envpool

        create_env_kwargs = create_env_kwargs or {}
        # EnvPool requires max_num_players to be set for single-player environments
        if "max_num_players" not in create_env_kwargs:
            create_env_kwargs["max_num_players"] = 1
        env = envpool.make(
            task_id=env_name,
            env_type="gym",
            num_envs=num_workers,
            gym_reset_return_info=True,
            **create_env_kwargs,
        )
        return super()._build_env(env)

    def _set_seed(self, seed: int | None) -> None:
        """Library EnvPool only supports setting a seed by recreating the environment."""
        if seed is not None:
            torchrl_logger.debug("Recreating EnvPool environment to set seed.")
            self.create_env_kwargs["seed"] = seed
            self._env = self._build_env(
                env_name=self.env_name,
                num_workers=self.num_workers,
                create_env_kwargs=self.create_env_kwargs,
            )

    def _check_kwargs(self, kwargs: dict):
        """Validate that the kwargs required to build the env are present.

        Raises:
            TypeError: if any required key is missing.
        """
        for arg in ("num_workers", "env_name", "create_env_kwargs"):
            if arg not in kwargs:
                raise TypeError(f"Expected '{arg}' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, num_workers={self.num_workers}, device={self.device})"