sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next_mtp.py (new file)
@@ -0,0 +1,109 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only Qwen3Next MTP Speculative Decoding."""
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        # if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
+        self.pp_group = get_pp_group()
+        # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")
+
+        # currently, based on the provided ckpt, we:
+        # (1) do not use dedicated MTP embeddings (none are provided in the ckpt)
+        #     and directly use the target model embeddings
+        # (2) hardcode bias=False since not provided
+        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+        RMSNorm_cls = GemmaRMSNorm
+        self.pre_fc_norm_embedding = RMSNorm_cls(
+            config.hidden_size, config.rms_norm_eps
+        )
+        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
+        config.num_hidden_layers = 1
+        config.full_attention_interval = 1
+        self.model = Qwen3NextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if input_embeds is None:
+            input_embeds = self.model.embed_tokens(input_ids)
+
+        input_embeds = self.pre_fc_norm_embedding(input_embeds)
+        hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
+        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
+
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            hidden_states,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ):
+        super().load_weights(weights, is_mtp=True)
+
+
+EntryClass = [Qwen3NextForCausalLMMTP]
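The fusion step in forward() is the core of this draft head: the draft token's embedding and the target model's last hidden state (forward_batch.spec_info.hidden_states) are each RMS-normalized, concatenated, and projected back down to hidden_size by fc before entering the single Qwen3Next layer. A minimal, self-contained sketch of just that step, with a hand-rolled, weightless RMSNorm standing in for GemmaRMSNorm and made-up sizes:

```python
import torch
from torch import nn

hidden_size, vocab_size, seq_len = 64, 1024, 8

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hand-rolled RMSNorm (no learnable weight), standing in for GemmaRMSNorm.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

embed = nn.Embedding(vocab_size, hidden_size)
fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

input_ids = torch.randint(0, vocab_size, (seq_len,))
# Stand-in for the target model's spec_info.hidden_states.
target_hidden = torch.randn(seq_len, hidden_size)

# Normalize each stream separately, then fuse with a single projection.
fused = fc(torch.cat((rms_norm(embed(input_ids)), rms_norm(target_hidden)), dim=-1))
print(fused.shape)  # torch.Size([8, 64]); this feeds the single decoder layer
```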
sglang/srt/models/torch_native_llama.py
@@ -22,7 +22,7 @@ Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
 
 Here is a quick example to enable TP:
 ```python
-from sglang.srt.model_parallel import tensor_parallel
+from sglang.srt.layers.model_parallel import tensor_parallel
 
 device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
 tensor_parallel(model, device_mesh)
sglang/srt/models/transformers.py
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
         """
         tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
 
-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
sglang/srt/multimodal/processors/base_processor.py
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
sglang/srt/multimodal/processors/glm4v.py
@@ -2,7 +2,6 @@ import re
 from typing import List, Union
 
 from decord import VideoReader
-from transformers.video_utils import VideoMetadata
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
@@ -66,17 +65,18 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
         total_num_frames = len(vr)
         duration = total_num_frames / video_fps if video_fps else 0
 
-        metadata = VideoMetadata(
-            total_num_frames=int(total_num_frames),
-            fps=float(video_fps),
-            duration=float(duration),
-            video_backend="decord",
-        )
-
         # Extract all frames
         indices = list(range(total_num_frames))
         frames = vr.get_batch(indices).asnumpy()
-        metadata.frames_indices = indices
+
+        # Return metadata as dict so transformers can properly create VideoMetadata objects
+        metadata = {
+            "total_num_frames": int(total_num_frames),
+            "fps": float(video_fps),
+            "duration": float(duration),
+            "video_backend": "decord",
+            "frames_indices": indices,
+        }
 
         return frames, metadata
 
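The dict carries the same fields the removed VideoMetadata constructor received, plus frames_indices, which previously had to be attached after construction. A tiny plain-Python sketch of the resulting shape, with made-up frame counts:

```python
# Keys mirror the fields of the removed VideoMetadata construction above.
total_num_frames, video_fps = 48, 24.0
metadata = {
    "total_num_frames": int(total_num_frames),
    "fps": float(video_fps),
    "duration": float(total_num_frames / video_fps if video_fps else 0),
    "video_backend": "decord",
    "frames_indices": list(range(total_num_frames)),
}
print(metadata["duration"])  # 2.0
```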
sglang/srt/multimodal/processors/internvl.py
@@ -2,8 +2,10 @@
 
 import numpy as np
 import torch
-from decord import VideoReader, cpu
+import torchvision.transforms as T
+from decord import VideoReader, cpu, gpu
 from PIL import Image
+from torchvision.transforms import InterpolationMode
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -48,99 +50,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)
 
-    @staticmethod
-    def build_transform(input_size):
-        IMAGENET_MEAN = (0.485, 0.456, 0.406)
-        IMAGENET_STD = (0.229, 0.224, 0.225)
-
-        def resize_image(img, size):
-            return img.resize((size, size), Image.Resampling.BICUBIC)
-
-        def to_tensor(img):
-            # Convert PIL Image to numpy array
-            img_array = np.array(img).astype(np.float32) / 255.0
-            # Convert HWC to CHW format
-            img_array = img_array.transpose(2, 0, 1)
-            return torch.from_numpy(img_array)
-
-        def normalize(tensor, mean, std):
-            mean = torch.tensor(mean).view(-1, 1, 1)
-            std = torch.tensor(std).view(-1, 1, 1)
-            return (tensor - mean) / std
-
-        def transform(img):
-            img = img.convert("RGB") if img.mode != "RGB" else img
-            img = resize_image(img, input_size)
-            tensor = to_tensor(img)
-            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-            return tensor
-
-        return transform
-
-    @staticmethod
-    def dynamic_preprocess(
-        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-    ):
-
-        def find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, width, height, image_size
-        ):
-            best_ratio_diff = float("inf")
-            best_ratio = (1, 1)
-            area = width * height
-            for ratio in target_ratios:
-                target_aspect_ratio = ratio[0] / ratio[1]
-                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                if ratio_diff < best_ratio_diff:
-                    best_ratio_diff = ratio_diff
-                    best_ratio = ratio
-                elif ratio_diff == best_ratio_diff:
-                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                        best_ratio = ratio
-            return best_ratio
-
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
     @staticmethod
     def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
         if bound:
@@ -160,27 +69,112 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
 
     @staticmethod
     def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        try:
+            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+            use_gpu = True
+        except (RuntimeError, OSError) as e:
+            print(
+                f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
+            )
+            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+            use_gpu = False
+
         max_frame = len(vr) - 1
         fps = float(vr.get_avg_fps())
 
-        pixel_values_list, num_patches_list = [], []
-        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        pixel_values_list = []
+        num_patches_list = []
         frame_indices = InternVLImageProcessor.get_index(
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
+
         for frame_index in frame_indices:
-            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
-            img = InternVLImageProcessor.dynamic_preprocess(
-                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            # Load frame
+            frame = vr[frame_index]
+            if use_gpu:
+                img = frame.cuda().permute(2, 0, 1).float() / 255.0
+            else:
+                img_np = frame.asnumpy()
+                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+            # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+            mean = img.mean(dim=[1, 2], keepdim=True)
+            # Prevent division by zero; clamp to minimum value of 1e-6
+            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+            img = (img - mean) / std
+
+            tiles = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, max_num=max_num, use_thumbnail=True
             )
-            pixel_values = [transform(tile) for tile in img]
-            pixel_values = torch.stack(pixel_values)
-            num_patches_list.append(pixel_values.shape[0])
-            pixel_values_list.append(pixel_values)
-        pixel_values = torch.cat(pixel_values_list)
+
+            pixel_values_list.append(tiles)
+            num_patches_list.append(tiles.shape[0])
+
+        pixel_values = torch.cat(pixel_values_list, dim=0)
         return pixel_values, num_patches_list
 
+    @staticmethod
+    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+        C, H, W = tensor.shape
+        aspect_ratio = W / H
+
+        # Generate all possible aspect ratios
+        target_ratios = set(
+            (i, j)
+            for n in range(1, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find closest ratio
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+
+        for x, y in target_ratios:
+            target_ar = x / y
+            diff = abs(aspect_ratio - target_ar)
+            blocks = x * y
+            best_blocks = best_ratio[0] * best_ratio[1]
+
+            if diff < best_ratio_diff:
+                best_ratio_diff = diff
+                best_ratio = (x, y)
+            elif diff == best_ratio_diff and blocks > best_blocks:
+                best_ratio = (x, y)
+
+        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+        blocks = best_ratio[0] * best_ratio[1]
+
+        # Resize on GPU
+        resized = torch.nn.functional.interpolate(
+            tensor.unsqueeze(0),
+            size=(target_h, target_w),
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze(0)
+
+        # Split into tiles
+        tiles = []
+        for i in range(blocks):
+            x = (i % best_ratio[0]) * image_size
+            y = (i // best_ratio[0]) * image_size
+            tile = resized[:, y : y + image_size, x : x + image_size]
+            tiles.append(tile)
+
+        # Add thumbnail if needed
+        if use_thumbnail and len(tiles) > 1:
+            thumb = torch.nn.functional.interpolate(
+                tensor.unsqueeze(0),
+                size=(image_size, image_size),
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+            tiles.append(thumb)
+
+        return torch.stack(tiles).to(torch.bfloat16)
+
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
@@ -191,53 +185,71 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             discard_alpha_channel=True,
         )
 
-        def process_image_internvl(image, input_size=448, max_num=12):
-            transform = InternVLImageProcessor.build_transform(input_size=input_size)
-            images = InternVLImageProcessor.dynamic_preprocess(
-                image, image_size=input_size, use_thumbnail=True, max_num=max_num
-            )
-            pixel_values = [transform(image) for image in images]
-            pixel_values = torch.stack(pixel_values)
-            return pixel_values
-
         num_patches_list = []
         pixel_values = []
+
         # Process each input with allocated frames
-        for image_index, (image) in enumerate(base_output.images):
+        for image_index, image in enumerate(base_output.images):
             try:
                 # TODO: video input
-                raw_image = process_image_internvl(image)
-                pixel_value = [raw_image.to(torch.bfloat16)]
-                pixel_values += pixel_value
-                num_patches = raw_image.shape[0]
-                num_patches_list += [num_patches]
-
-            except FileNotFoundError as e:
-                print(e)
+                # Convert PIL to GPU tensor
+                if isinstance(image, Image.Image):
+                    img_np = np.array(image.convert("RGB"))
+                    tensor = (
+                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                    )
+                else:
+                    tensor = image.cuda()  # assume already tensor
+
+                # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+                mean = tensor.mean(dim=[1, 2], keepdim=True)
+                # Prevent division by zero; clamp to minimum value of 1e-6
+                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+                tensor = (tensor - mean) / std
+                tiles = self.dynamic_preprocess(
+                    tensor, image_size=448, max_num=12, use_thumbnail=True
+                )
+
+                pixel_values.append(tiles)
+                num_patches_list.append(tiles.shape[0])
+
+            except Exception as e:
+                print(f"[Error] Failed to process image {image_index}: {e}")
                 return None
 
+        # Concatenate all
         pixel_values = torch.cat(pixel_values, dim=0)
 
         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
         input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
 
-        for idx, num_patches in enumerate(num_patches_list):
+        input_text_updated = input_text
+        for num_patches in num_patches_list:
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+            input_text_updated = input_text_updated.replace(
+                original_placeholder, image_tokens, 1
+            )
 
-        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
+        input_text_updated = input_text_updated.replace(
+            original_placeholder, self.IMG_CONTEXT_TOKEN
+        )
 
-        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+        # Tokenize
+        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
            "input_ids"
         ].flatten()
+        input_ids = input_ids_tensor.tolist()
+
+        # Get image token offsets
         image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
+            input_ids=input_ids_tensor.to("cuda"),
             mm_token_id=self.mm_tokens.image_token_id,
         )
+
         items = [
             MultimodalDataItem(
                 feature=pixel_values,
@@ -247,7 +259,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         ]
 
         return {
-            "input_ids": input_ids.tolist(),
+            "input_ids": input_ids,
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py}
@@ -1,7 +1,7 @@
 import re
 from typing import Dict, Optional, Tuple, Type
 
-from sglang.srt.harmony_parser import HarmonyParser
+from sglang.srt.parser.harmony_parser import HarmonyParser
 
 
 class StreamingParseResult:
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Set, Type
+import weakref
+from typing import TYPE_CHECKING, Optional, Set, Type
 
 import torch
 
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
         penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
-        self.batch = batch
+        self._batch_ref = weakref.ref(batch)
         self.device = batch.device
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
 
@@ -27,6 +28,17 @@
             is_required |= pen_is_required
         self.is_required = is_required
 
+    @property
+    def batch(self) -> ScheduleBatch | None:
+        return self._batch_ref()
+
+    @batch.setter
+    def batch(self, value: Optional[ScheduleBatch]):
+        if value is None:
+            self._batch_ref = lambda: None
+        else:
+            self._batch_ref = weakref.ref(value)
+
     def reqs(self):
         return self.batch.reqs
 
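Holding the batch through weakref.ref breaks the strong reference cycle between the orchestrator and its ScheduleBatch, so a finished batch can be garbage-collected even while the orchestrator object lives on. A minimal sketch of the pattern, with a bare stand-in class for the batch:

```python
import gc
import weakref

class Batch:
    """Stand-in for ScheduleBatch."""

class Orchestrator:
    def __init__(self, batch: Batch):
        self._batch_ref = weakref.ref(batch)  # no strong reference kept

    @property
    def batch(self):
        return self._batch_ref()  # None once the batch has been collected

batch = Batch()
orch = Orchestrator(batch)
assert orch.batch is batch

del batch  # drop the only strong reference
gc.collect()
print(orch.batch)  # None: the orchestrator did not keep the batch alive
```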
sglang/srt/sampling/sampling_batch_info.py
@@ -67,28 +67,31 @@ class SamplingBatchInfo:
     logit_bias: Optional[torch.Tensor] = None
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def _get_global_server_args_dict(cls):
         from sglang.srt.managers.schedule_batch import global_server_args_dict
 
+        return global_server_args_dict
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        global_server_args_dict = cls._get_global_server_args_dict()
+
         reqs = batch.reqs
         device = batch.device
-        temperatures = (
-            torch.tensor(
-                [r.sampling_params.temperature for r in reqs],
-                dtype=torch.float,
-            )
-            .view(-1, 1)
-            .to(device, non_blocking=True)
-        )
+        temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
         top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
         top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
+        )
         min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
 
         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):
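The sampling-info change constructs each tensor directly on the target device instead of building it on the CPU and then copying with .to(device, non_blocking=True); note that non_blocking copies only overlap with compute when the source is in pinned host memory. A small before/after sketch that runs whether or not CUDA is present:

```python
import torch

values = [0.7, 1.0, 0.9]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate on the CPU, then issue a separate copy to the device.
t_old = torch.tensor(values, dtype=torch.float).to(device, non_blocking=True)

# After: allocate directly on the target device in one step.
t_new = torch.tensor(values, dtype=torch.float, device=device).view(-1, 1)

print(t_old.device, t_new.shape)  # e.g. cuda:0 torch.Size([3, 1])
```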