sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/starcoder2.py
@@ -0,0 +1,357 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/starcoder2.py
+""" PyTorch Starcoder2 model."""
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Starcoder2Config
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class Starcoder2Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_id: int = 0,
+    ):
+        super().__init__()
+        self.config = config
+
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = config.rope_theta
+        self.max_position_embeddings = config.max_position_embeddings
+        self.use_bias = config.use_bias
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=self.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=self.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Starcoder2MLP(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.c_fc = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class Starcoder2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Starcoder2Attention(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = Starcoder2MLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+        )
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.norm_epsilon
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Starcoder2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
+        )
+
+        pp_group = get_pp_group()
+        pp_size = pp_group.world_size
+        pp_rank = pp_group.rank
+        self.start_layer = pp_rank * config.num_hidden_layers // pp_size
+        self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Starcoder2DecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix=f"{prefix}.layers",
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = inputs_embeds
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class Starcoder2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: Starcoder2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.model = Starcoder2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
+                prefix=f"{prefix}.lm_head",
+            )
+        self.logits_processor = LogitsProcessor(config=config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            inputs_embeds=inputs_embeds,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freqs" in name:
+                continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                    is_stacked = True
+                    break
+            if is_stacked:
+                continue

+            param = params_dict.get(name)
+            if param is None:
+                continue
+
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = Starcoder2ForCausalLM
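
The KV-head handling in Starcoder2Attention follows the usual grouped-query-attention arithmetic under tensor parallelism: query heads are always split across ranks, while KV heads are partitioned when there are at least as many KV heads as ranks and replicated otherwise. A minimal sketch of that logic, detached from SGLang (the function name and example shapes are illustrative):

def split_heads(total_q_heads: int, total_kv_heads: int, tp_size: int):
    # Query heads must divide evenly across tensor-parallel ranks.
    assert total_q_heads % tp_size == 0
    num_q_heads = total_q_heads // tp_size
    if total_kv_heads >= tp_size:
        # Partition KV heads across ranks.
        assert total_kv_heads % tp_size == 0
    else:
        # Replicate KV heads: several ranks share each KV head.
        assert tp_size % total_kv_heads == 0
    num_kv_heads = max(1, total_kv_heads // tp_size)
    return num_q_heads, num_kv_heads

# e.g. 32 query heads and 4 KV heads with TP=8:
# 4 query heads per rank, and the 4 KV heads are replicated so each rank holds 1.
print(split_heads(32, 4, 8))  # (4, 1)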
sglang/srt/models/torch_native_llama.py
@@ -66,8 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
 
-tp_size = get_tensor_model_parallel_world_size()
-tp_rank = get_tensor_model_parallel_rank()
+tp_size: Optional[int] = None
+tp_rank: Optional[int] = None
 
 
 def gate_up_proj_weight_loader(
@@ -341,6 +341,13 @@ class LlamaModel(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+
+        global tp_size, tp_rank
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
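
The torch_native_llama.py change above replaces import-time calls to the tensor-parallel accessors with module-level None sentinels that are filled in on first model construction, so the module can be imported before the distributed environment exists. The general shape of that lazy-initialization pattern (a sketch; _query_world_size is a hypothetical stand-in for the real accessor):

from typing import Optional

_tp_size: Optional[int] = None  # module-level cache, resolved lazily

def _query_world_size() -> int:
    return 1  # stand-in for get_tensor_model_parallel_world_size()

def get_tp_size() -> int:
    # At import time the distributed process group may not exist yet;
    # querying it on first use removes the import-order constraint.
    global _tp_size
    if _tp_size is None:
        _tp_size = _query_world_size()
    return _tp_size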
sglang/srt/models/utils.py
@@ -0,0 +1,51 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
+def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+def create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
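
A sketch of how the two helpers in the new sglang/srt/models/utils.py might be used together at an attention call site. Only enable_fused_set_kv_buffer and create_fused_set_kv_buffer_arg come from the hunk above; the wrapper function and its argument names are hypothetical:

import torch

def maybe_fused_kv_arg(v: torch.Tensor, layer, forward_batch):
    # Returns a FusedSetKVBufferArg when the fused path is eligible
    # (CUDA with a bfloat16 KV cache), else None so the caller falls back
    # to a separate set_kv_buffer pass after the attention op.
    if enable_fused_set_kv_buffer(forward_batch):
        return create_fused_set_kv_buffer_arg(v, layer, forward_batch)
    return None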
sglang/srt/multimodal/processors/base_processor.py
@@ -234,19 +234,27 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda" if not _is_npu else "npu"
+            if not _is_npu:
+                kwargs["device"] = "cuda"
+            elif processor.__class__.__name__ not in {
+                "Qwen2_5_VLProcessor",
+                "Qwen3VLProcessor",
+            }:
+                # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
+                kwargs["device"] = "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
            return_tensors="pt",
             **kwargs,
         )
-        # move feature tensors to cpu
-        for feature_name in self.FEATURE_NAMES:
-            if feature_name in result and isinstance(
-                result[feature_name], torch.Tensor
-            ):
-                result[feature_name] = result[feature_name].to("cpu")
+        if not self.server_args.keep_mm_feature_on_device:
+            # move feature tensors to cpu
+            for feature_name in self.FEATURE_NAMES:
+                if feature_name in result and isinstance(
+                    result[feature_name], torch.Tensor
+                ):
+                    result[feature_name] = result[feature_name].to("cpu")
 
         return result
 
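
The new keep_mm_feature_on_device check above trades device memory for transfer cost: offloading each preprocessed feature tensor to the CPU frees accelerator memory between preprocessing and the forward pass, but adds a host-to-device copy later. A reduced, standalone sketch of that placement decision (place_features and its defaults are illustrative, not SGLang API):

import torch

def place_features(result: dict, keep_on_device: bool, feature_names=("pixel_values",)):
    # Keeping features resident skips two copies; offloading frees memory.
    if keep_on_device:
        return result
    for name in feature_names:
        if name in result and isinstance(result[name], torch.Tensor):
            result[name] = result[name].to("cpu")
    return result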
sglang/srt/multimodal/processors/dots_vlm.py
@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 
 from PIL import Image
 
+from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
 from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
 
 
 class DotsVLMImageProcessor(BaseMultimodalProcessor):
-    models = [DotsVLMForCausalLM]
+    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]
 
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
             for image in base_output.images
         ]
         base_output.images = await asyncio.gather(*resize_tasks)
-
        combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
             base_output, self.mm_tokens
         )
-
         if combined_mm_item is None:
             return None
 
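
Adding DotsOCRForCausalLM to the models list above is the whole registration step: each multimodal processor declares the model classes it serves, and a registry maps architectures to processors. A simplified sketch of that dispatch pattern (the stub classes and registry here are illustrative, not SGLang's actual registry code):

class DotsVLMForCausalLM: ...
class DotsOCRForCausalLM: ...

PROCESSOR_REGISTRY = {}

def register(processor_cls):
    # One processor class can serve several model architectures.
    for model_cls in processor_cls.models:
        PROCESSOR_REGISTRY[model_cls] = processor_cls
    return processor_cls

@register
class DotsVLMImageProcessor:
    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]

assert PROCESSOR_REGISTRY[DotsOCRForCausalLM] is DotsVLMImageProcessor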
sglang/srt/multimodal/processors/internvl.py
@@ -1,5 +1,7 @@
 # Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
 
+from functools import lru_cache
+
 import numpy as np
 import torch
 import torchvision.transforms as T
@@ -19,6 +21,20 @@ from sglang.srt.multimodal.processors.base_processor import (
 class InternVLImageProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel, InternS1ForConditionalGeneration]
 
+    IMAGENET_MEAN = [0.485, 0.456, 0.406]
+    IMAGENET_STD = [0.229, 0.224, 0.225]
+
+    @staticmethod
+    @lru_cache(maxsize=1)
+    def _get_normalize_tensors(device="cuda", dtype=torch.float32):
+        mean = torch.tensor(
+            InternVLImageProcessor.IMAGENET_MEAN, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        std = torch.tensor(
+            InternVLImageProcessor.IMAGENET_STD, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        return mean, std
+
     def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
         image_size = (
@@ -88,6 +104,8 @@
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
 
+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         for frame_index in frame_indices:
             # Load frame
             frame = vr[frame_index]
@@ -97,10 +115,6 @@
             img_np = frame.asnumpy()
             img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
 
-            # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
-            mean = img.mean(dim=[1, 2], keepdim=True)
-            # Prevent division by zero; clamp to minimum value of 1e-6
-            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
             img = (img - mean) / std
 
             tiles = InternVLImageProcessor.dynamic_preprocess(
@@ -188,6 +202,8 @@
         num_patches_list = []
         pixel_values = []
 
+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         # Process each input with allocated frames
         for image_index, image in enumerate(base_output.images):
             try:
@@ -201,10 +217,6 @@
                 else:
                     tensor = image.cuda() # assume already tensor
 
-                # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
-                mean = tensor.mean(dim=[1, 2], keepdim=True)
-                # Prevent division by zero; clamp to minimum value of 1e-6
-                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
                 tensor = (tensor - mean) / std
                 tiles = self.dynamic_preprocess(
                     tensor, image_size=448, max_num=12, use_thumbnail=True
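
The internvl.py hunks above switch from per-image mean/std statistics to fixed ImageNet constants, built once and memoized with lru_cache, so identical inputs always normalize identically. The cached-constant pattern in isolation (a sketch that runs on CPU; the 448-pixel tiling and video decoding are omitted):

from functools import lru_cache
import torch

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

@lru_cache(maxsize=1)
def get_normalize_tensors(device="cpu", dtype=torch.float32):
    # Shape (3, 1, 1) so the constants broadcast over (C, H, W) images.
    mean = torch.tensor(IMAGENET_MEAN, device=device, dtype=dtype).view(-1, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=device, dtype=dtype).view(-1, 1, 1)
    return mean, std

img = torch.rand(3, 448, 448)  # RGB image already scaled to [0, 1]
mean, std = get_normalize_tensors()
img = (img - mean) / std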
sglang/srt/multimodal/processors/qwen_vl.py
@@ -12,6 +12,8 @@ from torchvision.transforms import InterpolationMode
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -209,7 +211,12 @@
 
 # Compatible with Qwen2VL and Qwen2_5VL
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+    models = [
+        Qwen2VLForConditionalGeneration,
+        Qwen2_5_VLForConditionalGeneration,
+        Qwen3VLForConditionalGeneration,
+        Qwen3VLMoeForConditionalGeneration,
+    ]
 
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
sglang/srt/multimodal/processors/sarashina2_vision.py
@@ -0,0 +1,81 @@
+from typing import List, Union
+
+from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+
+
+class Sarashina2VisionProcessor(BaseMultimodalProcessor):
+    models = [Sarashina2VisionForCausalLM]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        # Sarashina2Vision specific tokens (default is <|file|>)
+        self.IMAGE_TOKEN = "<|file|>"
+        self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14)
+        self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397)
+        self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398)
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_id=self.IM_TOKEN_ID,
+        ).build(_processor)
+
+        # Patch the processor's image processor to handle parameter compatibility
+        if hasattr(_processor, "image_processor") and hasattr(
+            _processor.image_processor, "_preprocess"
+        ):
+            original_preprocess = _processor.image_processor._preprocess
+
+            def patched_preprocess(*args, **kwargs):
+                # Filter kwargs to only include parameters that the custom _preprocess method accepts
+                # Based on Sarashina2VisionImageProcessor._preprocess signature
+                allowed_params = {
+                    "do_resize",
+                    "resample",
+                    "do_rescale",
+                    "rescale_factor",
+                    "do_normalize",
+                    "image_mean",
+                    "image_std",
+                    "do_convert_rgb",
+                    "data_format",
+                    "input_data_format",
+                }
+                filtered_kwargs = {
+                    k: v for k, v in kwargs.items() if k in allowed_params
+                }
+                return original_preprocess(*args, **filtered_kwargs)
+
+            _processor.image_processor._preprocess = patched_preprocess
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        """Process image data for Sarashina2Vision model using standard SGLang pattern."""
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output=base_output,
+            mm_tokens=self.mm_tokens,
+        )
+
+        return {
+            "mm_items": mm_items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": self.mm_tokens.image_token_id,
+            "im_start_id": self.IM_START_ID,
+            "im_end_id": self.IM_END_ID,
+        }
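
The patched_preprocess shim above hard-codes the keyword arguments that Sarashina2Vision's custom _preprocess accepts and drops everything else. The same guard can be derived from the wrapped function's signature instead of a hand-written set; a generic sketch (not the shipped code):

import inspect
from functools import wraps

def accept_known_kwargs(fn):
    # Drop any keyword argument the wrapped function's signature doesn't name.
    allowed = set(inspect.signature(fn).parameters)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **{k: v for k, v in kwargs.items() if k in allowed})

    return wrapper

@accept_known_kwargs
def _preprocess(images, do_resize=True, do_rescale=True):
    return images

_preprocess([], do_resize=False, return_tensors="pt")  # extra kwarg silently dropped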
sglang/srt/parser/jinja_template_utils.py
@@ -89,6 +89,12 @@ def detect_jinja_template_content_format(chat_template: str) -> str:
     - If template has loops like {%- for content in message['content'] -%} → 'openai'
     - Otherwise → 'string'
     """
+    # Shortcut for multimodal templates
+    if any(
+        keyword in chat_template for keyword in ["image", "audio", "video", "vision"]
+    ):
+        return "openai"
+
     jinja_ast = _try_extract_ast(chat_template)
     if jinja_ast is None:
         return "string"
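
The shortcut added to detect_jinja_template_content_format classifies a template as 'openai' whenever it mentions a modality keyword, before any AST inspection runs. Its behavior in isolation (a self-contained sketch of the same substring test):

MULTIMODAL_KEYWORDS = ["image", "audio", "video", "vision"]

def looks_multimodal(chat_template: str) -> bool:
    # A plain substring test: templates that iterate over typed content
    # parts almost always name at least one modality.
    return any(keyword in chat_template for keyword in MULTIMODAL_KEYWORDS)

assert looks_multimodal("{% for content in message['content'] %}{{ content['image'] }}{% endfor %}")
assert not looks_multimodal("{{ message['content'] }}")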