torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/llm/backends/vllm/vllm_plugin.py
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ from torchrl._utils import logger
+
+
+ def register_fp32_overrides() -> None:
+     """Register FP32 overrides for vLLM models."""
+     from vllm.model_executor.models.registry import ModelRegistry
+
+     # ======= Register models here =======
+     # Register Qwen3 models with FP32 override
+     ModelRegistry.register_model(
+         "Qwen3ForCausalLM",
+         "torchrl.modules.llm.backends._models:Qwen3ForCausalLMFP32",
+     )
+
+     logger.info("Registered Qwen3 FP32 model overrides")
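A minimal sketch of how the override registered above could be exercised directly, assuming vLLM is installed and the plugin has not already been loaded through an entry point (the model id is illustrative):

    from vllm import LLM
    from torchrl.modules.llm.backends.vllm.vllm_plugin import register_fp32_overrides

    # Swap in the FP32-output Qwen3 implementation before the engine resolves the model class.
    register_fp32_overrides()
    llm = LLM(model="Qwen/Qwen3-0.6B")  # illustrative model id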
torchrl/modules/llm/backends/vllm/vllm_sync.py
@@ -0,0 +1,446 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Synchronous vLLM backend for TorchRL.
+
+ From https://docs.vllm.ai/en/v0.7.0/getting_started/examples/rlhf.html
+ """
+
+ from __future__ import annotations
+
+ import os
+ from collections.abc import Iterator
+ from contextlib import nullcontext
+
+ import torch
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.modules.llm.utils import _cuda_visible_devices
+
+ from .base import RLvLLMEngine
+ from .vllm_utils import stateless_init_process_group
+
+ try:
+     from vllm import LLM
+     from vllm.worker.worker import Worker
+
+     _has_vllm = True
+ except ImportError:
+
+     class LLM:
+         """Placeholder for LLM class when vLLM is not installed."""
+
+     class Worker:
+         """Placeholder for Worker class when vLLM is not installed."""
+
+     _has_vllm = False
+
+ # get_open_port may not be available in all vLLM versions
+ try:
+     from vllm.utils import get_open_port
+ except ImportError:
+
+     def get_open_port():
+         """Fallback get_open_port using standard library."""
+         import socket
+
+         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+             s.bind(("", 0))
+             return s.getsockname()[1]
+
+
+ class _vLLMWorker(Worker):
+     """Private vLLM worker for Ray.
+
+     vLLMParameterServer will always take rank 0 in the stateless process group
+     initialized by this worker, and the tp ranks associated with the LLM class
+     will be in the range [1, tp_size].
+     """
+
+     def __init__(self, *args, **kwargs):
+         if not _has_vllm:
+             raise ImportError(
+                 "vllm is not installed. Please install it with `pip install vllm`."
+             )
+
+         torchrl_logger.info(f"=> in {type(self).__name__}.__init__")
+         torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}")
+         torchrl_logger.info(f"device count {torch.cuda.device_count()}")
+         super().__init__(*args, **kwargs)
+
+     def init_weight_update_group(
+         self, master_address, master_port, rank_offset, world_size
+     ):
+         from vllm.distributed.parallel_state import get_world_group
+
+         torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
+
+         # Get the local rank within the tensor parallel group
+         tp_group = get_world_group()
+         local_rank = tp_group.rank
+         torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")
+
+         # Calculate the global rank for weight update group
+         # rank_offset is 1, so ranks will be [1, 2] for tp_size=2
+         rank = local_rank + rank_offset
+         torchrl_logger.info(
+             f"Initializing {type(self).__name__} weight update group with "
+             f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}"
+         )
+
+         self.model_update_group = stateless_init_process_group(
+             master_address,
+             master_port,
+             rank,
+             world_size,
+             self.device,
+         )
+
+         torchrl_logger.info(f"{type(self).__name__}.init_weight_update_group success")
+
+     def update_weight_broadcast(self, name, dtype, shape):
+         weight = torch.empty(shape, dtype=dtype, device="cuda")
+         self.model_update_group.broadcast(
+             weight, src=0, stream=torch.cuda.current_stream()
+         )
+
+         self.model_runner.model.load_weights(weights=[(name, weight)])
+         del weight
+
+     def update_weight(self, name, weight):
+         self.model_runner.model.load_weights(weights=[(name, weight)])
+         del weight
+
+     def check_weights_changed(self):
+         """Check if the weights are updated to 0."""
+         # TODO: This is a test and should be treated as such
+         weights_updated = True
+         for p in self.model_runner.model.parameters():
+             weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
+         return weights_updated
+
+
+ class _LLMOnDevice(LLM):
+     """Private wrapper around `vllm.LLM` to control its placement devices."""
+
+     def __init__(self, *args, bundle_indices: list | None = None, **kwargs):
+         if not _has_vllm:
+             raise ImportError(
+                 "vllm is not installed. Please install it with `pip install vllm`."
+             )
+
+         # Stop Ray from manipulating CUDA_VISIBLE_DEVICES at the top-level
+         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+
+         # Configure GPU utilization for Ray workers
+         if bundle_indices is not None:
+             os.environ[
+                 "VLLM_RAY_PER_WORKER_GPUS"
+             ] = "0.4" # Allow multiple workers per GPU
+             os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
+             torchrl_logger.info(
+                 f"Initializing LLM with bundle_indices={bundle_indices}"
+             )
+
+         self.args = args
+         self.kwargs = kwargs
+
+     def initialize(self):
+         # Let vLLM handle device placement
+         super().__init__(*self.args, **self.kwargs)
+         return True
+
+
+ class RayLLMWorker(RLvLLMEngine):
+     """A wrapper for Ray-based vLLM workers that implements the RLvLLMEngine interface.
+
+     This class wraps a Ray actor handle for a vLLM worker and provides the
+     standardized interface for weight updates and configuration access.
+     """
+
+     def __init__(self, ray_actor, tensor_parallel_size: int, model_name: str):
+         self.ray_actor = ray_actor
+         self._tensor_parallel_size = tensor_parallel_size
+         self._model_name = model_name
+         self._master_address = None
+         self._master_port = None
+
+     def get_tp_size(self) -> int:
+         """Get the tensor parallel size."""
+         return self._tensor_parallel_size
+
+     def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
+         """Get model parameter metadata.
+
+         For Ray workers, this requires loading the model to inspect parameters.
+         Currently returns an empty dict - should be implemented when needed.
+         """
+         # TODO: Implement metadata extraction from Ray worker
+         torchrl_logger.warning(
+             "RayLLMWorker.get_model_metadata() not implemented - returning empty dict"
+         )
+         return {}
+
+     def get_master_address(self) -> str:
+         """Get the master address for weight synchronization."""
+         if self._master_address is None:
+             self._master_address = "localhost"
+         return self._master_address
+
+     def get_master_port(self) -> int:
+         """Get the master port for weight synchronization."""
+         if self._master_port is None:
+             self._master_port = get_open_port() if callable(get_open_port) else 29500
+         return self._master_port
+
+     def init_weight_update_group(self) -> None:
+         """Initialize the weight update communication group."""
+         weight_sync_world_size = self._tensor_parallel_size + 1
+
+         try:
+             import ray
+
+             # Initialize weight update group on the Ray worker
+             ray.get(
+                 self.ray_actor.collective_rpc.remote(
+                     "init_weight_update_group",
+                     args=(
+                         self.get_master_address(),
+                         self.get_master_port(),
+                         1,
+                         weight_sync_world_size,
+                     ),
+                 )
+             )
+             torchrl_logger.info("Ray worker weight update group initialized")
+         except ImportError:
+             raise ImportError(
+                 "Ray not available for weight update group initialization"
+             )
+
+     def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
+         """Update model weights via the Ray worker.
+
+         Args:
+             weights: Iterator yielding (parameter_name, tensor) tuples
+         """
+         try:
+             import ray
+
+             # Convert iterator to list for Ray serialization
+             weights_list = list(weights)
+
+             if not weights_list:
+                 torchrl_logger.warning("No weights provided for update")
+                 return
+
+             torchrl_logger.info(
+                 f"Updating {len(weights_list)} parameters on Ray worker"
+             )
+
+             # Send weights to the Ray worker
+             remotes = []
+             for name, weight in weights_list:
+                 remotes.append(
+                     self.ray_actor.collective_rpc.remote(
+                         "update_weight", args=(name, weight.to("cuda"))
+                     )
+                 )
+
+             ray.get(remotes)
+             torchrl_logger.info("Ray worker weight update completed")
+
+         except ImportError:
+             raise ImportError("Ray not available for weight updates")
+
+     # Delegate generation methods to the Ray actor
+     def generate(self, *args, **kwargs):
+         """Generate text using the Ray worker."""
+         try:
+             import ray
+
+             return ray.get(self.ray_actor.generate.remote(*args, **kwargs))
+         except ImportError:
+             raise ImportError("Ray not available for generation")
+
+
+ class LocalLLMWrapper(RLvLLMEngine):
+     """A wrapper for local vLLM.LLM instances that implements the RLvLLMEngine interface.
+
+     This wrapper provides the standardized interface for local vLLM instances,
+     though weight updates are not applicable since the model is in the same process.
+     """
+
+     def __init__(self, llm_instance, tensor_parallel_size: int, model_name: str):
+         self.llm_instance = llm_instance
+         self._tensor_parallel_size = tensor_parallel_size
+         self._model_name = model_name
+         self._master_address = None
+         self._master_port = None
+
+     def get_tp_size(self) -> int:
+         """Get the tensor parallel size."""
+         return self._tensor_parallel_size
+
+     def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
+         """Get model parameter metadata.
+
+         For local LLM instances, this would require accessing the model directly.
+         Currently returns an empty dict.
+         """
+         # TODO: Implement metadata extraction from local LLM
+         torchrl_logger.warning(
+             "LocalLLMWrapper.get_model_metadata() not implemented - returning empty dict"
+         )
+         return {}
+
+     def get_master_address(self) -> str:
+         """Get the master address for weight synchronization."""
+         if self._master_address is None:
+             self._master_address = "localhost"
+         return self._master_address
+
+     def get_master_port(self) -> int:
+         """Get the master port for weight synchronization."""
+         if self._master_port is None:
+             self._master_port = get_open_port() if callable(get_open_port) else 29500
+         return self._master_port
+
+     def init_weight_update_group(self) -> None:
+         """Initialize the weight update communication group."""
+         torchrl_logger.info("Local LLM weight update group initialized (no-op)")
+
+     def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
+         """Update model weights.
+
+         For local LLM instances, weight updates are not applicable since
+         the model is in the same process space.
+         """
+         weights_list = list(weights)
+         torchrl_logger.info(
+             f"Local LLM weight update (no-op) for {len(weights_list)} parameters"
+         )
+
+     # Delegate generation methods to the local LLM
+     def generate(self, *args, **kwargs):
+         """Generate text using the local LLM."""
+         return self.llm_instance.generate(*args, **kwargs)
+
+
+ def make_vllm_worker(
+     *,
+     model_name: str,
+     devices: list[torch.device | int] | None = None,
+     num_devices: int | None = None,
+     make_ray_worker: bool = True,
+     enforce_eager: bool = False,
+     enable_fp32_output: bool = False,
+     **kwargs,
+ ) -> RayLLMWorker | LocalLLMWrapper:
+     """Creates a vLLM inference engine with tensor parallelism support.
+
+     Args:
+         model_name (str): The model name to pass to vLLM.LLM.
+         devices (list[torch.device | int], optional): List of devices to use. Exclusive with num_devices.
+         num_devices (int, optional): Number of devices to use. Exclusive with devices.
+         make_ray_worker (bool, optional): Whether to create a Ray actor. Defaults to True.
+         enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to `False`.
+         enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
+             This can help with numerical stability for certain models. Requires model-specific support in
+             torchrl.modules.llm.backends._models.
+         **kwargs: Additional arguments passed to vLLM.LLM.__init__.
+
+     Returns:
+         RayLLMWorker | LocalLLMWrapper: Either a Ray worker wrapper or a local LLM wrapper, both implementing RLvLLMEngine.
+
+     Example:
+         >>> # Create a 2-GPU tensor parallel worker with Ray
+         >>> worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2)
+         >>> # Create a local LLM instance on GPU 1
+         >>> llm = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", devices=[1], make_ray_worker=False)
+         >>> # Create with FP32 output enabled
+         >>> worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2, enable_fp32_output=True)
+     """
+     if not _has_vllm:
+         raise ImportError(
+             "vllm is not installed. Please install it with `pip install vllm`."
+         )
+
+     # Set FP32 output environment variable if requested
+     if enable_fp32_output:
+         os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
+         torchrl_logger.info(
+             "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
+             "This will use FP32 for the final output layer if the model supports it."
+         )
+
+     # Handle device specification
+     if num_devices is not None and devices is not None:
+         raise ValueError("Cannot specify both num_devices and devices")
+     if num_devices is not None:
+         devices = None
+     elif devices is None:
+         devices = [0] # Default to first GPU
+         num_devices = 1
+     elif len(devices) > 1:
+         # Convert devices to indices
+         devices = [
+             torch.device(device).index if not isinstance(device, int) else device
+             for device in devices
+         ]
+         num_devices = len(devices)
+
+     # Validate devices
+     if devices is not None:
+         for d in devices:
+             if not isinstance(d, int) or d < 0 or d >= torch.cuda.device_count():
+                 raise ValueError(f"Invalid device index: {d}")
+
+     if make_ray_worker:
+         import ray
+
+         if not ray.is_initialized():
+             raise RuntimeError("Ray is not initialized")
+
+         torchrl_logger.info(
+             f"Creating vLLM Ray worker with tensor_parallel_size={num_devices}"
+         )
+
+         # Configure Ray remote class with minimal resources
+         # Let vLLM handle GPU allocation through environment variables
+         worker_cls = ray.remote(
+             num_cpus=4, # Minimal CPU request
+             num_gpus=0, # Let vLLM handle GPU allocation
+         )(_LLMOnDevice)
+
+         # Create worker with tensor parallelism config
+         worker = worker_cls.remote(
+             model=model_name,
+             bundle_indices=devices, # Pass device indices to _LLMOnDevice
+             tensor_parallel_size=num_devices,
+             distributed_executor_backend="ray",
+             enforce_eager=enforce_eager,
+             worker_cls="torchrl.modules.llm.backends.vllm.vllm_sync._vLLMWorker",
+             **kwargs,
+         )
+         ray.get(worker.initialize.remote())
+
+         # Wrap the Ray actor in RayLLMWorker to provide RLvLLMEngine interface
+         return RayLLMWorker(worker, num_devices or 1, model_name)
+
+     else:
+         # Local non-Ray mode - use LLM directly
+         with _cuda_visible_devices(devices) if devices is not None else nullcontext():
+             torchrl_logger.info(
+                 f"Creating local vLLM LLM with tensor_parallel_size={num_devices}, devices={devices}"
+             )
+             llm_instance = LLM(
+                 model=model_name,
+                 tensor_parallel_size=num_devices,
+                 enforce_eager=True,
+                 **kwargs,
+             )
+
+         # Wrap the local LLM to provide RLvLLMEngine interface
+         return LocalLLMWrapper(llm_instance, num_devices or 1, model_name)
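A minimal usage sketch for the Ray path above, assuming Ray is already initialized on a CUDA host and that the checkpoint's parameter names are accepted by the worker's load_weights (the model id is illustrative):

    import ray
    from transformers import AutoModelForCausalLM
    from torchrl.modules.llm.backends.vllm.vllm_sync import make_vllm_worker

    ray.init()
    worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2)

    # Push the current training weights to the inference engine, then sample from it.
    train_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
    worker.update_weights((name, p.data) for name, p in train_model.named_parameters())
    outputs = worker.generate(["The capital of France is"])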
torchrl/modules/llm/backends/vllm/vllm_utils.py
@@ -0,0 +1,129 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Shared utilities for vLLM backends."""
+
+ from __future__ import annotations
+
+ import torch
+
+ from torchrl._utils import logger as torchrl_logger
+
+ try:
+     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+     from vllm.distributed.utils import StatelessProcessGroup
+
+     _has_vllm = True
+ except ImportError:
+     PyNcclCommunicator = None
+     StatelessProcessGroup = None
+     _has_vllm = False
+
+ # get_open_port may not be available in all vLLM versions
+ try:
+     from vllm.utils import get_open_port
+ except ImportError:
+
+     def get_open_port():
+         """Fallback get_open_port using standard library."""
+         import socket
+
+         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+             s.bind(("", 0))
+             return s.getsockname()[1]
+
+
+ def stateless_init_process_group(
+     master_address: str | None, master_port: str | None, rank, world_size, device=None
+ ):
+     """Initializes a stateless process group for distributed communication.
+
+     Creates a `StatelessProcessGroup` instance without relying on the global
+     process group in `torch.distributed`. This approach is recommended for
+     initializing data-plane communication (NCCL) between external processes
+     (e.g., training processes) and vLLM workers.
+
+     Args:
+         master_address (str | None): The address of the master node. Defaults to "localhost" if not specified.
+         master_port (str | None): The port used by the master node. Automatically assigns an open port if not specified.
+         rank (int): The rank of the current process.
+         world_size (int): The total number of processes in the distributed group.
+         device: The device to use for communication. Defaults to None.
+
+     Returns:
+         PyNcclCommunicator: A PyNcclCommunicator instance initialized with the created StatelessProcessGroup.
+     """
+     if not _has_vllm:
+         raise ImportError(
+             "vllm is not installed. Please install it with `pip install vllm`."
+         )
+
+     if StatelessProcessGroup is None or PyNcclCommunicator is None:
+         raise ImportError(
+             "vllm is not installed. Please install it with `pip install vllm`."
+         )
+
+     if master_address is None:
+         master_address = "localhost" # get_ip()
+     if master_port is None:
+         master_port = get_open_port() if callable(get_open_port) else 29500
+
+     torchrl_logger.info(
+         f"Initializing stateless process group: rank={rank}, world_size={world_size}, master_address={master_address}, master_port={master_port}"
+     )
+     pg = StatelessProcessGroup.create(
+         host=master_address, port=int(master_port), rank=rank, world_size=world_size
+     )
+     if device is None:
+         device = torch.device("cuda:0")
+     pynccl = PyNcclCommunicator(pg, device=device)
+     return pynccl
+
+
+ async def stateless_init_process_group_async(
+     master_address: str | None,
+     master_port: str | None,
+     rank: int,
+     world_size: int,
+     device,
+ ):
+     """Initializes a stateless process group for distributed communication (async version).
+
+     Creates a `StatelessProcessGroup` instance without relying on the global
+     process group in `torch.distributed`. This approach is recommended for
+     initializing data-plane communication (NCCL) between external processes
+     (e.g., training processes) and vLLM workers.
+
+     Args:
+         master_address (str | None): The address of the master node. Defaults to "localhost" if not specified.
+         master_port (str | None): The port used by the master node. Automatically assigns an open port if not specified.
+         rank (int): The rank of the current process.
+         world_size (int): The total number of processes in the distributed group.
+         device: The device to use for communication.
+
+     Returns:
+         PyNcclCommunicator: A PyNcclCommunicator instance initialized with the created StatelessProcessGroup.
+     """
+     if not _has_vllm:
+         raise ImportError(
+             "vllm is not installed. Please install it with `pip install vllm`."
+         )
+
+     if StatelessProcessGroup is None or PyNcclCommunicator is None:
+         raise ImportError(
+             "vllm is not installed. Please install it with `pip install vllm`."
+         )
+
+     if master_address is None:
+         master_address = "localhost"
+     if master_port is None:
+         master_port = get_open_port() if callable(get_open_port) else 29500
+
+     master_port_int = int(master_port) if master_port is not None else 0
+     pg = StatelessProcessGroup.create(
+         host=master_address, port=master_port_int, rank=rank, world_size=world_size
+     )
+     pynccl = PyNcclCommunicator(pg, device=device)
+     return pynccl
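For context, a sketch of the trainer-side (rank 0) counterpart to `_vLLMWorker.update_weight_broadcast` above; the address, port, tensor-parallel size and parameter shape are illustrative, and the workers are assumed to have joined the same group via their own init_weight_update_group call:

    import torch
    from torchrl.modules.llm.backends.vllm.vllm_utils import stateless_init_process_group

    tp_size = 2  # must match the vLLM workers' tensor_parallel_size
    group = stateless_init_process_group(
        master_address="localhost",
        master_port="29500",
        rank=0,                  # rank 0 acts as the parameter server
        world_size=tp_size + 1,  # workers occupy ranks 1..tp_size
        device=torch.device("cuda:0"),
    )

    # Each worker receives this tensor in update_weight_broadcast via the same src/stream call.
    weight = torch.randn(4096, 4096, device="cuda:0", dtype=torch.bfloat16)
    group.broadcast(weight, src=0, stream=torch.cuda.current_stream())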
torchrl/modules/llm/policies/__init__.py
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """LLM policy wrappers.
+
+ This subpackage includes optional wrappers that may rely on native extensions
+ (e.g. vLLM). To avoid pulling in those optional dependencies eagerly, they
+ are not imported at module import time.
+ """
+
+ from __future__ import annotations
+
+ from .common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
+ from .transformers_wrapper import RemoteTransformersWrapper, TransformersWrapper
+ from .vllm_wrapper import vLLMWrapper
+
+ __all__ = [
+     "TransformersWrapper",
+     "RemoteTransformersWrapper",
+     "vLLMWrapper",
+     "LLMWrapperBase",
+     "Text",
+     "LogProbs",
+     "Masks",
+     "Tokens",
+     "ChatHistory",
+ ]