sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next_mtp.py
@@ -0,0 +1,109 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only Qwen3Next MTP Speculative Decoding."""
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        # if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
+        self.pp_group = get_pp_group()
+        # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")
+
+        # currently based on the provided ckpt, we:
+        # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
+        # (2) hardcode bias=False since not provided
+        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+        RMSNorm_cls = GemmaRMSNorm
+        self.pre_fc_norm_embedding = RMSNorm_cls(
+            config.hidden_size, config.rms_norm_eps
+        )
+        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
+        config.num_hidden_layers = 1
+        config.full_attention_interval = 1
+        self.model = Qwen3NextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if input_embeds is None:
+            input_embeds = self.model.embed_tokens(input_ids)
+
+        input_embeds = self.pre_fc_norm_embedding(input_embeds)
+        hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
+        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
+
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            hidden_states,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ):
+        super().load_weights(weights, is_mtp=True)
+
+
+EntryClass = [Qwen3NextForCausalLMMTP]
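For readers skimming the new file: the MTP head reuses the target model's embeddings, RMS-normalizes the current token embedding and the target model's last hidden state separately, concatenates the two, and projects back to `hidden_size` before the single-layer `Qwen3NextModel` runs. A minimal sketch of that fusion step in plain PyTorch (sizes are hypothetical, and a vanilla RMSNorm stands in for the file's GemmaRMSNorm):

```python
import torch
from torch import nn

hidden_size, num_tokens = 2048, 4  # hypothetical sizes


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm stand-in; the file actually uses GemmaRMSNorm (which scales by 1 + weight).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


w_emb = torch.ones(hidden_size)   # pre_fc_norm_embedding weight
w_hid = torch.ones(hidden_size)   # pre_fc_norm_hidden weight
fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

token_embeds = torch.randn(num_tokens, hidden_size)   # embed_tokens(input_ids)
target_hidden = torch.randn(num_tokens, hidden_size)  # forward_batch.spec_info.hidden_states

# Normalize each stream, concatenate on the feature dim, project back to hidden_size;
# the result is what the single-layer draft model consumes.
fused = fc(torch.cat((rms_norm(token_embeds, w_emb), rms_norm(target_hidden, w_hid)), dim=-1))
assert fused.shape == (num_tokens, hidden_size)
```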
sglang/srt/models/torch_native_llama.py
@@ -22,7 +22,7 @@ Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html

 Here is a quick example to enable TP:
 ```python
-from sglang.srt.model_parallel import tensor_parallel
+from sglang.srt.layers.model_parallel import tensor_parallel

 device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
 tensor_parallel(model, device_mesh)
sglang/srt/models/transformers.py
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
         """
         tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}

-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
sglang/srt/multimodal/processors/base_processor.py
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()


 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
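Both base_processor.py hunks serve one line: the fast image-processor path now picks its device at import time instead of hard-coding CUDA. A minimal sketch of the pattern (the `torch_npu` probe is an assumption about how `is_npu()` might detect Ascend hardware; SGLang's real helper may differ):

```python
import importlib.util


def is_npu() -> bool:
    # Assumption: an importable torch_npu package signals an Ascend NPU build.
    return importlib.util.find_spec("torch_npu") is not None


_is_npu = is_npu()  # evaluated once at module import, as in the hunk

kwargs = {"device": "cuda" if not _is_npu else "npu"}
```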
sglang/srt/multimodal/processors/glm4v.py
@@ -2,7 +2,6 @@ import re
 from typing import List, Union

 from decord import VideoReader
-from transformers.video_utils import VideoMetadata

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
@@ -66,17 +65,18 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
         total_num_frames = len(vr)
         duration = total_num_frames / video_fps if video_fps else 0

-        metadata = VideoMetadata(
-            total_num_frames=int(total_num_frames),
-            fps=float(video_fps),
-            duration=float(duration),
-            video_backend="decord",
-        )
-
         # Extract all frames
         indices = list(range(total_num_frames))
         frames = vr.get_batch(indices).asnumpy()
-        metadata.frames_indices = indices
+
+        # Return metadata as dict so transformers can properly create VideoMetadata objects
+        metadata = {
+            "total_num_frames": int(total_num_frames),
+            "fps": float(video_fps),
+            "duration": float(duration),
+            "video_backend": "decord",
+            "frames_indices": indices,
+        }

         return frames, metadata

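The hunk's own comment explains the intent: hand transformers a plain dict and let it construct `VideoMetadata` itself. A standalone sketch of producing that dict with decord (the input file path is hypothetical; the decord calls are the same ones the hunk uses):

```python
from decord import VideoReader, cpu

vr = VideoReader("clip.mp4", ctx=cpu(0))  # hypothetical input file
video_fps = float(vr.get_avg_fps())
total_num_frames = len(vr)
duration = total_num_frames / video_fps if video_fps else 0

# Extract every frame as a (num_frames, H, W, 3) uint8 array.
indices = list(range(total_num_frames))
frames = vr.get_batch(indices).asnumpy()

# Same shape of dict the processor now returns alongside the frames.
metadata = {
    "total_num_frames": int(total_num_frames),
    "fps": float(video_fps),
    "duration": float(duration),
    "video_backend": "decord",
    "frames_indices": indices,
}
```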
sglang/srt/multimodal/processors/internvl.py
@@ -2,8 +2,10 @@

 import numpy as np
 import torch
-from decord import VideoReader, cpu
+import torchvision.transforms as T
+from decord import VideoReader, cpu, gpu
 from PIL import Image
+from torchvision.transforms import InterpolationMode

 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -48,99 +50,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)

-    @staticmethod
-    def build_transform(input_size):
-        IMAGENET_MEAN = (0.485, 0.456, 0.406)
-        IMAGENET_STD = (0.229, 0.224, 0.225)
-
-        def resize_image(img, size):
-            return img.resize((size, size), Image.Resampling.BICUBIC)
-
-        def to_tensor(img):
-            # Convert PIL Image to numpy array
-            img_array = np.array(img).astype(np.float32) / 255.0
-            # Convert HWC to CHW format
-            img_array = img_array.transpose(2, 0, 1)
-            return torch.from_numpy(img_array)
-
-        def normalize(tensor, mean, std):
-            mean = torch.tensor(mean).view(-1, 1, 1)
-            std = torch.tensor(std).view(-1, 1, 1)
-            return (tensor - mean) / std
-
-        def transform(img):
-            img = img.convert("RGB") if img.mode != "RGB" else img
-            img = resize_image(img, input_size)
-            tensor = to_tensor(img)
-            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-            return tensor
-
-        return transform
-
-    @staticmethod
-    def dynamic_preprocess(
-        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-    ):
-
-        def find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, width, height, image_size
-        ):
-            best_ratio_diff = float("inf")
-            best_ratio = (1, 1)
-            area = width * height
-            for ratio in target_ratios:
-                target_aspect_ratio = ratio[0] / ratio[1]
-                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                if ratio_diff < best_ratio_diff:
-                    best_ratio_diff = ratio_diff
-                    best_ratio = ratio
-                elif ratio_diff == best_ratio_diff:
-                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                        best_ratio = ratio
-            return best_ratio
-
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
     @staticmethod
     def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
         if bound:
@@ -160,27 +69,112 @@ class InternVLImageProcessor(BaseMultimodalProcessor):

     @staticmethod
     def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        try:
+            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+            use_gpu = True
+        except (RuntimeError, OSError) as e:
+            print(
+                f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
+            )
+            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+            use_gpu = False
+
         max_frame = len(vr) - 1
         fps = float(vr.get_avg_fps())

-        pixel_values_list, num_patches_list = [], []
-        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        pixel_values_list = []
+        num_patches_list = []
         frame_indices = InternVLImageProcessor.get_index(
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
+
         for frame_index in frame_indices:
-            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
-            img = InternVLImageProcessor.dynamic_preprocess(
-                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            # Load frame
+            frame = vr[frame_index]
+            if use_gpu:
+                img = frame.cuda().permute(2, 0, 1).float() / 255.0
+            else:
+                img_np = frame.asnumpy()
+                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+            # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+            mean = img.mean(dim=[1, 2], keepdim=True)
+            # Prevent division by zero; clamp to minimum value of 1e-6
+            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+            img = (img - mean) / std
+
+            tiles = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, max_num=max_num, use_thumbnail=True
             )
-            pixel_values = [transform(tile) for tile in img]
-            pixel_values = torch.stack(pixel_values)
-            num_patches_list.append(pixel_values.shape[0])
-            pixel_values_list.append(pixel_values)
-        pixel_values = torch.cat(pixel_values_list)
+
+            pixel_values_list.append(tiles)
+            num_patches_list.append(tiles.shape[0])
+
+        pixel_values = torch.cat(pixel_values_list, dim=0)
         return pixel_values, num_patches_list

+    @staticmethod
+    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+        C, H, W = tensor.shape
+        aspect_ratio = W / H
+
+        # Generate all possible aspect ratios
+        target_ratios = set(
+            (i, j)
+            for n in range(1, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find closest ratio
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+
+        for x, y in target_ratios:
+            target_ar = x / y
+            diff = abs(aspect_ratio - target_ar)
+            blocks = x * y
+            best_blocks = best_ratio[0] * best_ratio[1]
+
+            if diff < best_ratio_diff:
+                best_ratio_diff = diff
+                best_ratio = (x, y)
+            elif diff == best_ratio_diff and blocks > best_blocks:
+                best_ratio = (x, y)
+
+        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+        blocks = best_ratio[0] * best_ratio[1]
+
+        # Resize on GPU
+        resized = torch.nn.functional.interpolate(
+            tensor.unsqueeze(0),
+            size=(target_h, target_w),
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze(0)
+
+        # Split into tiles
+        tiles = []
+        for i in range(blocks):
+            x = (i % best_ratio[0]) * image_size
+            y = (i // best_ratio[0]) * image_size
+            tile = resized[:, y : y + image_size, x : x + image_size]
+            tiles.append(tile)
+
+        # Add thumbnail if needed
+        if use_thumbnail and len(tiles) > 1:
+            thumb = torch.nn.functional.interpolate(
+                tensor.unsqueeze(0),
+                size=(image_size, image_size),
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+            tiles.append(thumb)
+
+        return torch.stack(tiles).to(torch.bfloat16)
+
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
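The new `dynamic_preprocess` replaces PIL cropping with tensor slicing, but the tiling decision is unchanged: pick the grid `(i, j)` with `i * j <= max_num` whose aspect ratio best matches the input, breaking ties toward more tiles. A self-contained rerun of just that ratio search, using arbitrary sample dimensions (1344x896, aspect ratio 1.5):

```python
W, H, max_num = 1344, 896, 12  # sample input, chosen only for illustration
aspect_ratio = W / H  # 1.5

# Same candidate set the hunk builds: all (i, j) grids with i * j <= max_num.
target_ratios = sorted(
    {
        (i, j)
        for n in range(1, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num
    },
    key=lambda x: x[0] * x[1],
)

best_ratio, best_diff = (1, 1), float("inf")
for x, y in target_ratios:
    diff = abs(aspect_ratio - x / y)
    blocks, best_blocks = x * y, best_ratio[0] * best_ratio[1]
    # Prefer the closer aspect ratio; on ties, prefer the grid with more tiles.
    if diff < best_diff or (diff == best_diff and blocks > best_blocks):
        best_ratio, best_diff = (x, y), diff

print(best_ratio)  # (3, 2): six 448x448 tiles, plus a thumbnail when use_thumbnail=True
```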
@@ -191,53 +185,71 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             discard_alpha_channel=True,
         )

-        def process_image_internvl(image, input_size=448, max_num=12):
-            transform = InternVLImageProcessor.build_transform(input_size=input_size)
-            images = InternVLImageProcessor.dynamic_preprocess(
-                image, image_size=input_size, use_thumbnail=True, max_num=max_num
-            )
-            pixel_values = [transform(image) for image in images]
-            pixel_values = torch.stack(pixel_values)
-            return pixel_values
-
         num_patches_list = []
         pixel_values = []
+
         # Process each input with allocated frames
-        for image_index, (image) in enumerate(base_output.images):
+        for image_index, image in enumerate(base_output.images):
             try:
                 # TODO: video input
-                raw_image = process_image_internvl(image)
-                pixel_value = [raw_image.to(torch.bfloat16)]
-                pixel_values += pixel_value
-                num_patches = raw_image.shape[0]
-                num_patches_list += [num_patches]
-
-            except FileNotFoundError as e:
-                print(e)
+                # Convert PIL to GPU tensor
+                if isinstance(image, Image.Image):
+                    img_np = np.array(image.convert("RGB"))
+                    tensor = (
+                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                    )
+                else:
+                    tensor = image.cuda()  # assume already tensor
+
+                # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+                mean = tensor.mean(dim=[1, 2], keepdim=True)
+                # Prevent division by zero; clamp to minimum value of 1e-6
+                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+                tensor = (tensor - mean) / std
+                tiles = self.dynamic_preprocess(
+                    tensor, image_size=448, max_num=12, use_thumbnail=True
+                )
+
+                pixel_values.append(tiles)
+                num_patches_list.append(tiles.shape[0])
+
+            except Exception as e:
+                print(f"[Error] Failed to process image {image_index}: {e}")
                 return None

+        # Concatenate all
         pixel_values = torch.cat(pixel_values, dim=0)

         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
         input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)

-        for idx, num_patches in enumerate(num_patches_list):
+        input_text_updated = input_text
+        for num_patches in num_patches_list:
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+            input_text_updated = input_text_updated.replace(
+                original_placeholder, image_tokens, 1
+            )

-        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
+        input_text_updated = input_text_updated.replace(
+            original_placeholder, self.IMG_CONTEXT_TOKEN
+        )

-        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+        # Tokenize
+        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
             "input_ids"
         ].flatten()
+        input_ids = input_ids_tensor.tolist()
+
+        # Get image token offsets
         image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
+            input_ids=input_ids_tensor.to("cuda"),
             mm_token_id=self.mm_tokens.image_token_id,
         )
+
         items = [
             MultimodalDataItem(
                 feature=pixel_values,
@@ -247,7 +259,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         ]

         return {
-            "input_ids": input_ids.tolist(),
+            "input_ids": input_ids,
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
sglang/srt/parser/conversation.py
@@ -26,6 +26,8 @@ Key components:
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
+import json
+import os
 import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -959,16 +961,42 @@ register_conv_template(
 )


+MODEL_TYPE_TO_TEMPLATE = {
+    "internvl_chat": "internvl-2-5",
+    "deepseek_vl_v2": "deepseek-vl2",
+    "multi_modality": "janus-pro",
+    "phi4mm": "phi-4-mm",
+    "minicpmv": "minicpmv",
+    "minicpmo": "minicpmo",
+}
+
+
+def get_model_type(model_path: str) -> Optional[str]:
+    config_path = os.path.join(model_path, "config.json")
+    if not os.path.exists(config_path):
+        return None
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        return config.get("model_type")
+    except (IOError, json.JSONDecodeError):
+        return None
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl", model_path, re.IGNORECASE):
         return "internvl-2-5"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)


 @register_conv_template_matching_function
 def match_deepseek_janus_pro(model_path: str):
     if re.search(r"janus", model_path, re.IGNORECASE):
         return "janus-pro"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)


 @register_conv_template_matching_function
@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
         return "deepseek-vl2"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)


 @register_conv_template_matching_function
@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):


 @register_conv_template_matching_function
-def match_openbmb_minicpm(model_path: str):
-    if re.search(r"minicpm-v", model_path, re.IGNORECASE):
-        return "minicpmv"
-    elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
-        return "minicpmo"
+def match_minicpm(model_path: str):
+    match = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
+    if match:
+        return f"minicpm{match.group(1).lower()}"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)


 @register_conv_template_matching_function
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
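The net effect of the conversation.py hunks: every matcher now falls back to the checkpoint's `config.json` `model_type` when the path regex misses, via the shared `MODEL_TYPE_TO_TEMPLATE` table. A self-contained exercise of that fallback (the temporary checkpoint directory is fabricated for the demo; `get_model_type` is reproduced from the hunk above):

```python
import json
import os
import tempfile

# Build a fake checkpoint dir whose name carries no template hint.
ckpt = tempfile.mkdtemp()
with open(os.path.join(ckpt, "config.json"), "w", encoding="utf-8") as f:
    json.dump({"model_type": "internvl_chat"}, f)

MODEL_TYPE_TO_TEMPLATE = {"internvl_chat": "internvl-2-5"}  # subset of the real table


def get_model_type(model_path):
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        return None
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f).get("model_type")


# The regex on the directory name fails, but the config-based fallback still resolves.
assert MODEL_TYPE_TO_TEMPLATE.get(get_model_type(ckpt)) == "internvl-2-5"
```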