sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next_mtp.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Inference-only Qwen3Next MTP Speculative Decoding."""
+ import logging
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+ from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
+ from sglang.srt.utils import add_prefix
+
+ logger = logging.getLogger(__name__)
+
+
+ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+         self.config = config
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.quant_config = quant_config
+         # if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
+         self.pp_group = get_pp_group()
+         # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")
+
+         # currently based on the provided ckpt, we:
+         # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
+         # (2) hardcode bias=False since not provided
+         self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+         RMSNorm_cls = GemmaRMSNorm
+         self.pre_fc_norm_embedding = RMSNorm_cls(
+             config.hidden_size, config.rms_norm_eps
+         )
+         self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
+         config.num_hidden_layers = 1
+         config.full_attention_interval = 1
+         self.model = Qwen3NextModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("model.shared_head.head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         if input_embeds is None:
+             input_embeds = self.model.embed_tokens(input_ids)
+
+         hidden_states = forward_batch.spec_info.hidden_states
+         # Some idle batch has 0 batch size. GemmaRMSNorm.forward would fail due to bs=0.
+         if not forward_batch.forward_mode.is_idle():
+             input_embeds = self.pre_fc_norm_embedding(input_embeds)
+             hidden_states = self.pre_fc_norm_hidden(hidden_states)
+             hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
+
+         hidden_states = self.model(
+             input_ids,
+             positions,
+             forward_batch,
+             hidden_states,
+         )
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+     ):
+         super().load_weights(weights, is_mtp=True)
+
+
+ EntryClass = [Qwen3NextForCausalLMMTP]
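The core of the new MTP head is the fusion step: the draft layer RMS-normalizes the target model's token embeddings and the incoming hidden states separately, concatenates them along the feature dimension, and projects the result back to hidden_size before running a single Qwen3Next layer. Below is a minimal standalone sketch of that fusion, for illustration only; it uses a plain hand-written RMSNorm rather than sglang's GemmaRMSNorm and is not the module above.

import torch
from torch import nn

class MTPFusion(nn.Module):
    """Illustrative sketch of the embedding/hidden-state fusion in an MTP draft head."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable scales for the two RMS norms (GemmaRMSNorm additionally adds 1 to its weight).
        self.w_emb = nn.Parameter(torch.ones(hidden_size))
        self.w_hid = nn.Parameter(torch.ones(hidden_size))
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * weight

    def forward(self, token_embeds: torch.Tensor, prev_hidden: torch.Tensor) -> torch.Tensor:
        # Normalize each stream, concatenate on the feature dim, project back to hidden_size.
        fused = torch.cat(
            (self._rms_norm(token_embeds, self.w_emb), self._rms_norm(prev_hidden, self.w_hid)),
            dim=-1,
        )
        return self.fc(fused)

fusion = MTPFusion(hidden_size=64)
out = fusion(torch.randn(4, 64), torch.randn(4, 64))
print(out.shape)  # torch.Size([4, 64])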
sglang/srt/models/step3_vl.py CHANGED
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
              use_grouped_topk=False,
          )

-         self.experts = get_moe_impl_class()(
+         self.experts = get_moe_impl_class(quant_config)(
              num_experts=config.moe_num_experts,
              top_k=config.moe_top_k,
              hidden_size=config.hidden_size,
sglang/srt/multimodal/processors/dots_vlm.py ADDED
@@ -0,0 +1,99 @@
+ import asyncio
+ import math
+ import re
+ from typing import Dict, List, Union
+
+ from PIL import Image
+
+ from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
+ from sglang.srt.multimodal.processors.base_processor import (
+     BaseMultimodalProcessor,
+     MultimodalSpecialTokens,
+ )
+ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
+
+
+ class DotsVLMImageProcessor(BaseMultimodalProcessor):
+     models = [DotsVLMForCausalLM]
+
+     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+         # The single, pre-expanded image token.
+         self.IMAGE_TOKEN = "<|img|><|imgpad|><|endofimg|>"
+         # The regex that matches expanded image tokens.
+         self.IMAGE_TOKEN_REGEX = re.compile(r"<\|img\|>(?:<\|imgpad\|>)+<\|endofimg\|>")
+
+         assert len(_processor.tokenizer.encode("<|img|>")) == 1
+         self.im_start_id = _processor.tokenizer.encode("<|img|>")[0]
+         self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0]
+         self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0]
+         self.IM_TOKEN_ID = self.image_token_id
+         self.IM_START_ID = self.im_start_id
+         self.IM_END_ID = self.im_end_id
+
+         vision_config = hf_config.vision_config
+         patch_size = vision_config.patch_size
+         merge_size = vision_config.spatial_merge_size
+
+         self.IMAGE_FACTOR = patch_size * merge_size
+         self.MIN_PIXELS = _processor.image_processor.min_pixels
+         self.MAX_PIXELS = _processor.image_processor.max_pixels
+         self.MAX_RATIO = 200
+         self.mm_tokens = MultimodalSpecialTokens(
+             image_token=self.IMAGE_TOKEN,
+             image_token_id=self.image_token_id,
+             image_token_regex=self.IMAGE_TOKEN_REGEX,
+         ).build(_processor)
+
+     async def process_mm_data_async(
+         self,
+         image_data: List[Union[str, bytes, Dict]],
+         input_text,
+         request_obj,
+         max_req_input_len,
+         *args,
+         **kwargs,
+     ):
+         if isinstance(image_data, str):
+             image_data = [image_data]
+
+         if (
+             isinstance(image_data, list)
+             and image_data
+             and isinstance(image_data[0], list)
+         ):
+             image_data = sum(image_data, [])
+
+         base_output = self.load_mm_data(
+             prompt=input_text,
+             image_data=image_data,
+             multimodal_tokens=self.mm_tokens,
+         )
+
+         # Qwen-specific: resize images if they are raw Image objects
+         if base_output.images and isinstance(base_output.images[0], Image.Image):
+             resize_tasks = [
+                 resize_image_async(
+                     image,
+                     min_pixels=self.MIN_PIXELS,
+                     max_pixels=self.MAX_PIXELS,
+                     size_factor=self.IMAGE_FACTOR,
+                 )
+                 for image in base_output.images
+             ]
+             base_output.images = await asyncio.gather(*resize_tasks)
+
+         combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
+             base_output, self.mm_tokens
+         )
+
+         if combined_mm_item is None:
+             return None
+
+         return {
+             "input_ids": input_ids.tolist(),
+             "mm_items": combined_mm_item,
+             "im_start_id": self.im_start_id,
+             "im_end_id": self.im_end_id,
+             "im_token_id": self.image_token_id,
+         }
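The processor accepts prompts that contain either the compact image token or an already expanded run of pad tokens: IMAGE_TOKEN_REGEX requires the <|img|>/<|endofimg|> delimiters and one or more <|imgpad|> tokens in between. A quick standalone check of that behaviour, using only the regex from the code above:

import re

IMAGE_TOKEN_REGEX = re.compile(r"<\|img\|>(?:<\|imgpad\|>)+<\|endofimg\|>")

# The compact token and an already-expanded run of pad tokens both match;
# an image block with no pad token does not.
print(bool(IMAGE_TOKEN_REGEX.fullmatch("<|img|><|imgpad|><|endofimg|>")))                 # True
print(bool(IMAGE_TOKEN_REGEX.fullmatch("<|img|>" + "<|imgpad|>" * 4 + "<|endofimg|>")))   # True
print(bool(IMAGE_TOKEN_REGEX.fullmatch("<|img|><|endofimg|>")))                           # False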
sglang/srt/multimodal/processors/glm4v.py CHANGED
@@ -2,7 +2,6 @@ import re
  from typing import List, Union

  from decord import VideoReader
- from transformers.video_utils import VideoMetadata

  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
  from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
@@ -66,17 +65,18 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
          total_num_frames = len(vr)
          duration = total_num_frames / video_fps if video_fps else 0

-         metadata = VideoMetadata(
-             total_num_frames=int(total_num_frames),
-             fps=float(video_fps),
-             duration=float(duration),
-             video_backend="decord",
-         )
-
          # Extract all frames
          indices = list(range(total_num_frames))
          frames = vr.get_batch(indices).asnumpy()
-         metadata.frames_indices = indices
+
+         # Return metadata as dict so transformers can properly create VideoMetadata objects
+         metadata = {
+             "total_num_frames": int(total_num_frames),
+             "fps": float(video_fps),
+             "duration": float(duration),
+             "video_backend": "decord",
+             "frames_indices": indices,
+         }

          return frames, metadata

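With this change the loader hands back a plain dict instead of a pre-built VideoMetadata object, leaving it to transformers to construct its own metadata type from the keys. A toy illustration of the dict shape and of a quantity a caller could derive from it (all numbers are made up):

# Toy example of the metadata dict returned above and a derived quantity.
metadata = {
    "total_num_frames": 120,
    "fps": 30.0,
    "duration": 4.0,
    "video_backend": "decord",
    "frames_indices": list(range(0, 120, 30)),
}

# Timestamp (in seconds) of each extracted frame.
timestamps = [idx / metadata["fps"] for idx in metadata["frames_indices"]]
print(timestamps)  # [0.0, 1.0, 2.0, 3.0]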
sglang/srt/multimodal/processors/internvl.py CHANGED
@@ -2,8 +2,10 @@

  import numpy as np
  import torch
- from decord import VideoReader, cpu
+ import torchvision.transforms as T
+ from decord import VideoReader, cpu, gpu
  from PIL import Image
+ from torchvision.transforms import InterpolationMode

  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -48,99 +50,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
              image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
          ).build(_image_processor)

-     @staticmethod
-     def build_transform(input_size):
-         IMAGENET_MEAN = (0.485, 0.456, 0.406)
-         IMAGENET_STD = (0.229, 0.224, 0.225)
-
-         def resize_image(img, size):
-             return img.resize((size, size), Image.Resampling.BICUBIC)
-
-         def to_tensor(img):
-             # Convert PIL Image to numpy array
-             img_array = np.array(img).astype(np.float32) / 255.0
-             # Convert HWC to CHW format
-             img_array = img_array.transpose(2, 0, 1)
-             return torch.from_numpy(img_array)
-
-         def normalize(tensor, mean, std):
-             mean = torch.tensor(mean).view(-1, 1, 1)
-             std = torch.tensor(std).view(-1, 1, 1)
-             return (tensor - mean) / std
-
-         def transform(img):
-             img = img.convert("RGB") if img.mode != "RGB" else img
-             img = resize_image(img, input_size)
-             tensor = to_tensor(img)
-             tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-             return tensor
-
-         return transform
-
-     @staticmethod
-     def dynamic_preprocess(
-         image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-     ):
-
-         def find_closest_aspect_ratio(
-             aspect_ratio, target_ratios, width, height, image_size
-         ):
-             best_ratio_diff = float("inf")
-             best_ratio = (1, 1)
-             area = width * height
-             for ratio in target_ratios:
-                 target_aspect_ratio = ratio[0] / ratio[1]
-                 ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                 if ratio_diff < best_ratio_diff:
-                     best_ratio_diff = ratio_diff
-                     best_ratio = ratio
-                 elif ratio_diff == best_ratio_diff:
-                     if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                         best_ratio = ratio
-             return best_ratio
-
-         orig_width, orig_height = image.size
-         aspect_ratio = orig_width / orig_height
-
-         # calculate the existing image aspect ratio
-         target_ratios = set(
-             (i, j)
-             for n in range(min_num, max_num + 1)
-             for i in range(1, n + 1)
-             for j in range(1, n + 1)
-             if i * j <= max_num and i * j >= min_num
-         )
-         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-         # find the closest aspect ratio to the target
-         target_aspect_ratio = find_closest_aspect_ratio(
-             aspect_ratio, target_ratios, orig_width, orig_height, image_size
-         )
-
-         # calculate the target width and height
-         target_width = image_size * target_aspect_ratio[0]
-         target_height = image_size * target_aspect_ratio[1]
-         blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-         # resize the image
-         resized_img = image.resize((target_width, target_height))
-         processed_images = []
-         for i in range(blocks):
-             box = (
-                 (i % (target_width // image_size)) * image_size,
-                 (i // (target_width // image_size)) * image_size,
-                 ((i % (target_width // image_size)) + 1) * image_size,
-                 ((i // (target_width // image_size)) + 1) * image_size,
-             )
-             # split the image
-             split_img = resized_img.crop(box)
-             processed_images.append(split_img)
-         assert len(processed_images) == blocks
-         if use_thumbnail and len(processed_images) != 1:
-             thumbnail_img = image.resize((image_size, image_size))
-             processed_images.append(thumbnail_img)
-         return processed_images
-
      @staticmethod
      def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
          if bound:
@@ -160,27 +69,112 @@ class InternVLImageProcessor(BaseMultimodalProcessor):

      @staticmethod
      def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-         vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+         try:
+             vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+             use_gpu = True
+         except (RuntimeError, OSError) as e:
+             print(
+                 f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
+             )
+             vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+             use_gpu = False
+
          max_frame = len(vr) - 1
          fps = float(vr.get_avg_fps())

-         pixel_values_list, num_patches_list = [], []
-         transform = InternVLImageProcessor.build_transform(input_size=input_size)
+         pixel_values_list = []
+         num_patches_list = []
          frame_indices = InternVLImageProcessor.get_index(
              bound, fps, max_frame, first_idx=0, num_segments=num_segments
          )
+
          for frame_index in frame_indices:
-             img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
-             img = InternVLImageProcessor.dynamic_preprocess(
-                 img, image_size=input_size, use_thumbnail=True, max_num=max_num
+             # Load frame
+             frame = vr[frame_index]
+             if use_gpu:
+                 img = frame.cuda().permute(2, 0, 1).float() / 255.0
+             else:
+                 img_np = frame.asnumpy()
+                 img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+             # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+             mean = img.mean(dim=[1, 2], keepdim=True)
+             # Prevent division by zero; clamp to minimum value of 1e-6
+             std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+             img = (img - mean) / std
+
+             tiles = InternVLImageProcessor.dynamic_preprocess(
+                 img, image_size=input_size, max_num=max_num, use_thumbnail=True
              )
-             pixel_values = [transform(tile) for tile in img]
-             pixel_values = torch.stack(pixel_values)
-             num_patches_list.append(pixel_values.shape[0])
-             pixel_values_list.append(pixel_values)
-         pixel_values = torch.cat(pixel_values_list)
+
+             pixel_values_list.append(tiles)
+             num_patches_list.append(tiles.shape[0])
+
+         pixel_values = torch.cat(pixel_values_list, dim=0)
          return pixel_values, num_patches_list

+     @staticmethod
+     def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+         C, H, W = tensor.shape
+         aspect_ratio = W / H
+
+         # Generate all possible aspect ratios
+         target_ratios = set(
+             (i, j)
+             for n in range(1, max_num + 1)
+             for i in range(1, n + 1)
+             for j in range(1, n + 1)
+             if i * j <= max_num
+         )
+         target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+         # Find closest ratio
+         best_ratio_diff = float("inf")
+         best_ratio = (1, 1)
+
+         for x, y in target_ratios:
+             target_ar = x / y
+             diff = abs(aspect_ratio - target_ar)
+             blocks = x * y
+             best_blocks = best_ratio[0] * best_ratio[1]
+
+             if diff < best_ratio_diff:
+                 best_ratio_diff = diff
+                 best_ratio = (x, y)
+             elif diff == best_ratio_diff and blocks > best_blocks:
+                 best_ratio = (x, y)
+
+         target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+         blocks = best_ratio[0] * best_ratio[1]
+
+         # Resize on GPU
+         resized = torch.nn.functional.interpolate(
+             tensor.unsqueeze(0),
+             size=(target_h, target_w),
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze(0)
+
+         # Split into tiles
+         tiles = []
+         for i in range(blocks):
+             x = (i % best_ratio[0]) * image_size
+             y = (i // best_ratio[0]) * image_size
+             tile = resized[:, y : y + image_size, x : x + image_size]
+             tiles.append(tile)
+
+         # Add thumbnail if needed
+         if use_thumbnail and len(tiles) > 1:
+             thumb = torch.nn.functional.interpolate(
+                 tensor.unsqueeze(0),
+                 size=(image_size, image_size),
+                 mode="bicubic",
+                 align_corners=False,
+             ).squeeze(0)
+             tiles.append(thumb)
+
+         return torch.stack(tiles).to(torch.bfloat16)
+
      async def process_mm_data_async(
          self, image_data, input_text, request_obj, **kwargs
      ):
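The rewritten load_video path standardizes every frame with its own statistics rather than the fixed ImageNet mean and std, and clamps the standard deviation so constant frames do not divide by zero. A minimal CPU-only sketch of that normalization (the real code operates on CUDA tensors):

import torch

img = torch.rand(3, 224, 224)  # stand-in for one decoded frame, values in [0, 1]

mean = img.mean(dim=[1, 2], keepdim=True)
std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)  # clamp guards constant channels
normalized = (img - mean) / std

print(normalized.mean(dim=[1, 2]))  # ~0 for every channel
print(normalized.std(dim=[1, 2]))   # ~1 for every channel

The tensor-based dynamic_preprocess then picks the tile grid whose aspect ratio is closest to the frame's; for example, a 448x896 (H x W) frame has aspect ratio 2.0, maps to a (2, 1) grid of two 448x448 tiles, and gains a third thumbnail tile when use_thumbnail is set.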
@@ -191,53 +185,71 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
              discard_alpha_channel=True,
          )

-         def process_image_internvl(image, input_size=448, max_num=12):
-             transform = InternVLImageProcessor.build_transform(input_size=input_size)
-             images = InternVLImageProcessor.dynamic_preprocess(
-                 image, image_size=input_size, use_thumbnail=True, max_num=max_num
-             )
-             pixel_values = [transform(image) for image in images]
-             pixel_values = torch.stack(pixel_values)
-             return pixel_values
-
          num_patches_list = []
          pixel_values = []
+
          # Process each input with allocated frames
-         for image_index, (image) in enumerate(base_output.images):
+         for image_index, image in enumerate(base_output.images):
              try:
                  # TODO: video input
-                 raw_image = process_image_internvl(image)
-                 pixel_value = [raw_image.to(torch.bfloat16)]
-                 pixel_values += pixel_value
-                 num_patches = raw_image.shape[0]
-                 num_patches_list += [num_patches]
-
-             except FileNotFoundError as e:
-                 print(e)
+                 # Convert PIL to GPU tensor
+                 if isinstance(image, Image.Image):
+                     img_np = np.array(image.convert("RGB"))
+                     tensor = (
+                         torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                     )
+                 else:
+                     tensor = image.cuda()  # assume already tensor
+
+                 # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+                 mean = tensor.mean(dim=[1, 2], keepdim=True)
+                 # Prevent division by zero; clamp to minimum value of 1e-6
+                 std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+                 tensor = (tensor - mean) / std
+                 tiles = self.dynamic_preprocess(
+                     tensor, image_size=448, max_num=12, use_thumbnail=True
+                 )
+
+                 pixel_values.append(tiles)
+                 num_patches_list.append(tiles.shape[0])
+
+             except Exception as e:
+                 print(f"[Error] Failed to process image {image_index}: {e}")
                  return None

+         # Concatenate all
          pixel_values = torch.cat(pixel_values, dim=0)

          original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
          input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)

-         for idx, num_patches in enumerate(num_patches_list):
+         input_text_updated = input_text
+         for num_patches in num_patches_list:
              image_tokens = (
                  self.IMG_START_TOKEN
                  + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                  + self.IMG_END_TOKEN
              )
-             input_text = input_text.replace(original_placeholder, image_tokens, 1)
+             input_text_updated = input_text_updated.replace(
+                 original_placeholder, image_tokens, 1
+             )

-         input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
+         input_text_updated = input_text_updated.replace(
+             original_placeholder, self.IMG_CONTEXT_TOKEN
+         )

-         input_ids = self.tokenizer(input_text, return_tensors="pt")[
+         # Tokenize
+         input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
              "input_ids"
          ].flatten()
+         input_ids = input_ids_tensor.tolist()
+
+         # Get image token offsets
          image_offsets = self.get_mm_items_offset(
-             input_ids=input_ids,
+             input_ids=input_ids_tensor.to("cuda"),
              mm_token_id=self.mm_tokens.image_token_id,
          )
+
          items = [
              MultimodalDataItem(
                  feature=pixel_values,
@@ -247,7 +259,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
          ]

          return {
-             "input_ids": input_ids.tolist(),
+             "input_ids": input_ids,
              "mm_items": items,
              "im_start_id": self.img_start_token_id,
              "im_end_id": self.img_end_token_id,
sglang/srt/multimodal/processors/qwen_vl.py CHANGED
@@ -67,10 +67,15 @@ def smart_resize(
      return h_bar, w_bar


- def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+ def resize_image(
+     image,
+     min_pixels: int = MIN_PIXELS,
+     max_pixels: int = MAX_PIXELS,
+     size_factor: int = IMAGE_FACTOR,
+ ) -> Image.Image:
      width, height = image.size
-     min_pixels = MIN_PIXELS
-     max_pixels = MAX_PIXELS
+     min_pixels = min_pixels
+     max_pixels = max_pixels
      resized_height, resized_width = smart_resize(
          height,
          width,
@@ -97,8 +102,13 @@ def floor_by_factor(number: int, factor: int) -> int:
      return math.floor(number / factor) * factor


- async def resize_image_async(image):
-     return resize_image(image)
+ async def resize_image_async(
+     image,
+     min_pixels: int = MIN_PIXELS,
+     max_pixels: int = MAX_PIXELS,
+     size_factor: int = IMAGE_FACTOR,
+ ):
+     return resize_image(image, min_pixels, max_pixels, size_factor)


  def smart_nframes(
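With the extra keyword arguments, callers can pass a per-model pixel budget instead of relying on the module-level defaults; the DotsVLM processor earlier in this diff does exactly that with values read from its image processor. A usage sketch with illustrative numbers (the budgets and size_factor below are not taken from any real config):

import asyncio

from PIL import Image

from sglang.srt.multimodal.processors.qwen_vl import resize_image_async

async def resize_all(images):
    tasks = [
        resize_image_async(
            img,
            min_pixels=256 * 28 * 28,   # illustrative lower budget
            max_pixels=1280 * 28 * 28,  # illustrative upper budget
            size_factor=28,             # e.g. patch_size * merge_size for this toy setup
        )
        for img in images
    ]
    return await asyncio.gather(*tasks)

resized = asyncio.run(resize_all([Image.new("RGB", (1000, 750))]))
print(resized[0].size)  # dimensions snapped to multiples of size_factor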
sglang/srt/offloader.py CHANGED
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
      def post_init(self):
          pass

+     @property
+     def forbid_copy_engine_usage(self):
+         return False
+

  class NoopOffloader(BaseOffloader):
      pass
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
          for i in range(self.prefetch_step):
              self.offloaders[i].start_onload()

+     @property
+     def forbid_copy_engine_usage(self):
+         return self.mode == "cpu"
+

  def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
      def _on_forward_end():
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
          return self.shm_cpu_data.to("cuda", non_blocking=True)


+ def update_param(param, new_tensor):
+     """Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
+
+     if param.device == new_tensor.device:
+         param.data = new_tensor
+     else:
+         assert param.device == torch.device(
+             "cpu"
+         ), f"{param.device=} {new_tensor.device=}"
+         param.data = _create_cpu_data(new_tensor, pin_memory=True)
+
+
  def _move_param_to_cpu(param, pin_memory: bool):
+     param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
+
+
+ def _create_cpu_data(data, pin_memory: bool):
      cpu_data = _empty_strided_like(
-         param.data,
+         data,
          device="cpu",
          pin_memory=pin_memory,
      )
-     cpu_data.copy_(param.data)
-     param.data = cpu_data
+     cpu_data.copy_(data)
+     return cpu_data


  def _move_param_to_meta(module, param_name):
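update_param and _create_cpu_data exist so that when a CPU-resident parameter receives new weights, the data is copied into a freshly allocated pinned host buffer and later host-to-device transfers can remain asynchronous. A standalone sketch of the same idea; it is not the sglang helper and replaces _empty_strided_like with a plain torch.empty:

import torch
from torch import nn

def update_param_sketch(param: nn.Parameter, new_tensor: torch.Tensor) -> None:
    """Copy new weights into a parameter, re-pinning the host buffer when the parameter lives on CPU."""
    if param.device == new_tensor.device:
        param.data = new_tensor
    else:
        assert param.device == torch.device("cpu")
        pinned = torch.empty(
            new_tensor.shape,
            dtype=new_tensor.dtype,
            device="cpu",
            pin_memory=torch.cuda.is_available(),  # pinning requires a CUDA runtime
        )
        pinned.copy_(new_tensor)  # cross-device copy works for CUDA sources too
        param.data = pinned

p = nn.Parameter(torch.zeros(4), requires_grad=False)
new_w = torch.ones(4, device="cuda") if torch.cuda.is_available() else torch.ones(4)
update_param_sketch(p, new_w)
print(p.data, p.data.device)  # tensor([1., 1., 1., 1.]) cpu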