sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch_server.py CHANGED
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
                 "json_schema": json_schema,
+                "stream_interval": stream_interval,
             },
             "return_logprob": return_logprob,
             "stream": True,
@@ -245,8 +253,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info").json()
+    if "tokenizer_path" in server_info:
+        tokenizer_path = server_info["tokenizer_path"]
+    elif "prefill" in server_info:
+        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)

     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +270,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +293,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
@@ -301,6 +315,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
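Note: the new --client-stream-interval flag is forwarded to the server as a per-request sampling parameter. A minimal sketch of the resulting /generate payload, assuming the server honors "stream_interval" as shown in the hunk above (field values are illustrative):

# Hypothetical payload built by run_one_case with --client-stream-interval 4
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
        "stream_interval": 4,  # assumed semantics: stream a chunk roughly every 4 decoded tokens
    },
    "return_logprob": False,
    "stream": True,
}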
sglang/bench_serving.py CHANGED
@@ -265,6 +265,138 @@ async def async_request_openai_completions(
     return output


+async def async_request_openai_chat_completions(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    """Makes a request to the OpenAI Chat Completions API.
+
+    Handles both streaming and non-streaming responses, including support
+    for image data in messages. Calculates and returns various performance
+    metrics.
+
+    Args:
+        request_func_input: Input parameters for the request.
+        pbar: Optional tqdm progress bar to update.
+
+    Returns:
+        RequestFuncOutput: Output of the request, including generated text,
+        latency, TTFT, ITL, and success status.
+    """
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        "chat/completions"
+    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
+
+    if request_func_input.image_data:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": request_func_input.image_data},
+                    },
+                    {"type": "text", "text": request_func_input.prompt},
+                ],
+            },
+        ]
+    else:
+        messages = [{"role": "user", "content": request_func_input.prompt}]
+
+    async with _create_bench_client_session() as session:
+        payload = {
+            "model": request_func_input.model,
+            "messages": messages,
+            "temperature": 0.0,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = get_auth_headers()
+
+        output = RequestFuncOutput.init_new(request_func_input)
+
+        generated_text = ""
+        output_len = request_func_input.output_len
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    if args.disable_stream:
+                        # Non-streaming response
+                        response_json = await response.json()
+                        output.generated_text = response_json["choices"][0]["message"][
+                            "content"
+                        ]
+                        output.success = True
+                        output.latency = time.perf_counter() - st
+                        output.ttft = (
+                            output.latency
+                        )  # For non-streaming, TTFT = total latency
+                        output.output_len = response_json.get("usage", {}).get(
+                            "completion_tokens", output_len
+                        )
+                    else:
+                        # Streaming response
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                            latency = time.perf_counter() - st
+                            if chunk == "[DONE]":
+                                pass
+                            else:
+                                data = json.loads(chunk)
+
+                                # Check if this chunk contains content
+                                delta = data.get("choices", [{}])[0].get("delta", {})
+                                content = delta.get("content", "")
+
+                                if content:
+                                    timestamp = time.perf_counter()
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(
+                                            timestamp - most_recent_timestamp
+                                        )
+
+                                    most_recent_timestamp = timestamp
+                                    generated_text += content
+
+                                # Check for usage info in final chunk
+                                output_len = (data.get("usage") or {}).get(
+                                    "completion_tokens", output_len
+                                )
+
+                        output.generated_text = generated_text
+                        output.success = True
+                        output.latency = latency
+                        output.output_len = output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_truss(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -544,6 +676,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.random_output_len,
+            apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
     else:
@@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
     "sglang-oai": async_request_openai_completions,
+    "sglang-oai-chat": async_request_openai_chat_completions,
     "vllm": async_request_openai_completions,
+    "vllm-chat": async_request_openai_chat_completions,
     "lmdeploy": async_request_openai_completions,
+    "lmdeploy-chat": async_request_openai_chat_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
     "truss": async_request_truss,
@@ -661,6 +797,7 @@ def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    apply_chat_template: bool = True,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
     """
@@ -670,6 +807,7 @@ def sample_mmmu_requests(
         num_requests: Number of requests to sample.
         tokenizer: Tokenizer to use for token counting.
         fixed_output_len: If provided, use this fixed output length for all requests.
+        apply_chat_template: Whether to apply the chat template to the prompt.
         random_sample: Whether to randomly sample or take the first N.

     Returns:
@@ -739,28 +877,30 @@ def sample_mmmu_requests(

         # Construct the prompt
         prompt = f"Question: {question}\n\nAnswer: "
-
-        try:
-            prompt = tokenizer.apply_chat_template(
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": image_data},
-                            },
-                            {"type": "text", "text": prompt},
-                        ],
-                    }
-                ],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-        except Exception as e:
-            # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
-            print(f"Error applying chat template: {e}, fallback to <image> tag")
-            prompt = f"<image>{prompt}"
+        if apply_chat_template:
+            try:
+                prompt = tokenizer.apply_chat_template(
+                    [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": image_data},
+                                },
+                                {"type": "text", "text": prompt},
+                            ],
+                        }
+                    ],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+            except Exception as e:
+                # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
+                print(
+                    f"Error applying chat template: {e}, fallback to <image> tag"
+                )
+                prompt = f"<image>{prompt}"

         # Calculate token lengths for text only (without image data)
         prompt_token_ids = tokenizer.encode(prompt)
@@ -1544,6 +1684,12 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/v1/completions"
         )
+    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
+        api_url = (
+            f"{args.base_url}/v1/chat/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/chat/completions"
+        )
     elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
sglang/srt/configs/internvl.py CHANGED
@@ -147,12 +147,14 @@ class InternLM2Config(PretrainedConfig):
         )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
+        if isinstance(rope_scaling_factor, int):
+            rope_scaling_factor = float(rope_scaling_factor)


 class InternVisionConfig(PretrainedConfig):
sglang/srt/configs/janus_pro.py CHANGED
@@ -19,7 +19,7 @@ from transformers import (
 from transformers.image_utils import to_numpy_array

 from sglang.srt.configs.utils import register_image_processor, register_processor
-from sglang.srt.mm_utils import expand2square
+from sglang.srt.multimodal.mm_utils import expand2square


 class DictToObject(dict):
sglang/srt/configs/model_config.py CHANGED
@@ -59,6 +59,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
+        hybrid_kvcache_ratio: Optional[float] = None,
         impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:

@@ -86,6 +87,18 @@ class ModelConfig:
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
+        self.is_hybrid = is_hybrid_model(
+            self.hf_config.architectures,
+            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
+            context_length=context_length,
+            attention_chunk_size=self.attention_chunk_size,
+        )
+        if self.is_hybrid is not None:
+            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
+                get_hybrid_layer_ids(
+                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
+                )
+            )

         if enable_multimodal is None:
             mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             impl=server_args.impl,
             **kwargs,
         )
@@ -565,6 +579,7 @@ multimodal_model_archs = [
     "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
+    "Gemma3nForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
@@ -578,6 +593,7 @@ multimodal_model_archs = [
     "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
+    "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
@@ -632,3 +648,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def is_hybrid_model(
+    model_architectures: List[str],
+    hybrid_kvcache_ratio: Optional[float],
+    context_length: Optional[int],
+    attention_chunk_size: Optional[int],
+):
+    if hybrid_kvcache_ratio is None:
+        return None
+    elif (
+        hybrid_kvcache_ratio > 0
+        and model_architectures[0] == "Llama4ForConditionalGeneration"
+        and context_length > attention_chunk_size
+    ):
+        return hybrid_kvcache_ratio
+    else:
+        return None
+
+
+def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
+    if "Llama4ForConditionalGeneration" in model_architectures:
+        swa_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
+        ]
+        full_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
+        ]
+    else:
+        raise ValueError(
+            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
+        )
+    return swa_attention_layer_ids, full_attention_layer_ids
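Note: the (i + 1) % 4 rule above interleaves three sliding-window layers with one full-attention layer. A standalone illustration with a hypothetical 8-layer model (not a real config):

num_hidden_layers = 8
swa_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
full_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
print(swa_layers)   # [0, 1, 2, 4, 5, 6]
print(full_layers)  # [3, 7]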
sglang/srt/configs/update_config.py ADDED
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+DEFAULT_MOE_PADDING_SIZE = 32
+
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
+
+def may_get_weight_block_size(model_config, load_config):
+    from sglang.srt.model_loader.loader import _get_quantization_config
+    from sglang.srt.model_loader.utils import get_model_architecture
+
+    model_class, _ = get_model_architecture(model_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
+
+    if quant_config is not None and hasattr(quant_config, "weight_block_size"):
+        return getattr(quant_config, "weight_block_size")
+    return None
+
+
+def get_moe_padding_size(weight_block_size):
+    if weight_block_size is not None:
+        # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+        assert (
+            len(weight_block_size) == 2
+        ), "Only len(weight_block_size) == 2 is supported"
+        assert (
+            weight_block_size[0] == weight_block_size[1]
+        ), "Only weight_block_size[0] == weight_block_size[1] is supported"
+
+        return weight_block_size[0]
+
+    return DEFAULT_MOE_PADDING_SIZE
+
+
+def get_num_heads_padding_size(tp_size, weight_block_size):
+    pad_size = (
+        tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
+    )
+    return pad_size
+
+
+def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
+    if hasattr(model_config.hf_config, attr_name):
+        attr_value = getattr(model_config.hf_config, attr_name)
+        if attr_value % intermediate_padding_size != 0:
+            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+            attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
+            setattr(model_config.hf_config, attr_name, attr_value)
+            setattr(model_config.hf_text_config, attr_name, attr_value)
+    return model_config
+
+
+def adjust_config_with_unaligned_cpu_tp(
+    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
+) -> ModelConfig:
+    # Support the case where the num_attention_heads is not divisible by the TP size.
+    weight_block_size = may_get_weight_block_size(model_config, load_config)
+
+    model_config.hf_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+    model_config.hf_text_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+
+    model_config.hf_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+    model_config.hf_text_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+
+    if (
+        model_config.num_attention_heads % tp_size != 0
+        or model_config.get_total_num_kv_heads() % tp_size != 0
+    ):
+        # Compute the head_dim using the model_config.num_attention_heads before padding
+        if not hasattr(model_config.hf_config, "head_dim"):
+            model_config.hf_config.head_dim = (
+                model_config.hidden_size // model_config.num_attention_heads
+            )
+
+        query_heads_per_kv = (
+            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
+        )
+        total_kv_heads = model_config.get_total_num_kv_heads()
+        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
+        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
+
+        model_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_text_config.num_key_value_heads = num_key_value_heads
+
+        num_attention_heads = num_key_value_heads * query_heads_per_kv
+        model_config.num_attention_heads = num_attention_heads
+        model_config.hf_config.num_attention_heads = num_attention_heads
+        model_config.hf_text_config.num_attention_heads = num_attention_heads
+
+    intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
+    model_config = update_intermediate_size(
+        model_config, "moe_intermediate_size", intermediate_padding_size
+    )
+    model_config = update_intermediate_size(
+        model_config, "intermediate_size", intermediate_padding_size
+    )
+
+    return model_config
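Note: a worked example of the padding rules in this new module (the numbers are illustrative, not taken from any real checkpoint): with tp_size=3 and a 128x128 block-quantized model, head counts are padded to a multiple of 6 (an odd tp_size is doubled) and intermediate sizes to a multiple of 3 * 128 = 384.

def pad_to_multiple(x, multiple):
    # same rounding-up behavior as pad_vocab_size
    return ((x + multiple - 1) // multiple) * multiple

tp_size = 3
weight_block_size = [128, 128]
pad_size = tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
print(pad_to_multiple(8, pad_size))           # 8 kv heads  -> 12
print(pad_to_multiple(11008, tp_size * 128))  # intermediate_size 11008 -> 11136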
sglang/srt/conversation.py CHANGED
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()

@@ -350,6 +351,23 @@ class Conversation:
                 else:
                     ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

@@ -823,6 +841,7 @@ register_conv_template(
         sep_style=SeparatorStyle.GEMMA3,
         stop_str=["<end_of_turn>"],
         image_token="<start_of_image>",
+        audio_token="<start_of_audio>",
     )
 )

@@ -903,6 +922,20 @@ register_conv_template(
 )


+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -955,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
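Note: the QWEN2_AUDIO separator numbers each audio placeholder in a message. A small standalone sketch using the template string registered above (the message content is made up):

audio_token = "Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
message = audio_token * 2 + "What do you hear in these clips?"
counter = 1
while audio_token in message:
    # replace one raw placeholder at a time, numbering them 1, 2, ...
    message = message.replace(audio_token, audio_token.format(idx=counter), 1)
    counter += 1
print(message)
# Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
# Audio 2: <|audio_bos|><|AUDIO|><|audio_eos|>
# What do you hear in these clips?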
sglang/srt/custom_op.py CHANGED
@@ -1,11 +1,12 @@
 from torch import nn

-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()


 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError

+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)

@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
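Note: CustomOp's dispatch now falls through to forward_npu when is_npu() reports an Ascend device. A hedged sketch of a subclass (the op itself is a made-up identity, not part of sglang):

from sglang.srt.custom_op import CustomOp

class IdentityOpExample(CustomOp):
    # forward_native is the generic fallback; forward_npu is selected on NPU builds
    def forward_native(self, x):
        return x

    def forward_npu(self, x):
        return x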
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -27,6 +27,8 @@ class KVArgs:
     decode_tp_size: int
     # for pp prefill
     prefill_pp_size: int
+    kv_head_num: int
+    page_size: int


 class KVPoll: