sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch_server.py CHANGED
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
                 "json_schema": json_schema,
+                "stream_interval": stream_interval,
             },
             "return_logprob": return_logprob,
             "stream": True,
@@ -245,8 +253,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)
 
-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info").json()
+    if "tokenizer_path" in server_info:
+        tokenizer_path = server_info["tokenizer_path"]
+    elif "prefill" in server_info:
+        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)
 
     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +270,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +293,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
@@ -301,6 +315,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
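Note on the hunks above: the new --client-stream-interval flag is forwarded to the server as a "stream_interval" entry inside sampling_params. As a rough illustration only (not part of the diff), a hand-written request against SGLang's native /generate endpoint would carry the knob like this; the server address and prompt are made-up example values:

import requests

base_url = "http://127.0.0.1:30000"  # example address of a running sglang server
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
        "stream_interval": 4,  # per-request interval, mirroring --client-stream-interval above
    },
    "stream": True,
}
# With streaming enabled the server answers with server-sent events ("data: {...}" lines).
with requests.post(base_url + "/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))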
sglang/bench_serving.py CHANGED
@@ -265,6 +265,138 @@ async def async_request_openai_completions(
     return output
 
 
+async def async_request_openai_chat_completions(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    """Makes a request to the OpenAI Chat Completions API.
+
+    Handles both streaming and non-streaming responses, including support
+    for image data in messages. Calculates and returns various performance
+    metrics.
+
+    Args:
+        request_func_input: Input parameters for the request.
+        pbar: Optional tqdm progress bar to update.
+
+    Returns:
+        RequestFuncOutput: Output of the request, including generated text,
+        latency, TTFT, ITL, and success status.
+    """
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        "chat/completions"
+    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
+
+    if request_func_input.image_data:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": request_func_input.image_data},
+                    },
+                    {"type": "text", "text": request_func_input.prompt},
+                ],
+            },
+        ]
+    else:
+        messages = [{"role": "user", "content": request_func_input.prompt}]
+
+    async with _create_bench_client_session() as session:
+        payload = {
+            "model": request_func_input.model,
+            "messages": messages,
+            "temperature": 0.0,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = get_auth_headers()
+
+        output = RequestFuncOutput.init_new(request_func_input)
+
+        generated_text = ""
+        output_len = request_func_input.output_len
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    if args.disable_stream:
+                        # Non-streaming response
+                        response_json = await response.json()
+                        output.generated_text = response_json["choices"][0]["message"][
+                            "content"
+                        ]
+                        output.success = True
+                        output.latency = time.perf_counter() - st
+                        output.ttft = (
+                            output.latency
+                        )  # For non-streaming, TTFT = total latency
+                        output.output_len = response_json.get("usage", {}).get(
+                            "completion_tokens", output_len
+                        )
+                    else:
+                        # Streaming response
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                            latency = time.perf_counter() - st
+                            if chunk == "[DONE]":
+                                pass
+                            else:
+                                data = json.loads(chunk)
+
+                                # Check if this chunk contains content
+                                delta = data.get("choices", [{}])[0].get("delta", {})
+                                content = delta.get("content", "")
+
+                                if content:
+                                    timestamp = time.perf_counter()
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(
+                                            timestamp - most_recent_timestamp
+                                        )
+
+                                    most_recent_timestamp = timestamp
+                                    generated_text += content
+
+                                # Check for usage info in final chunk
+                                output_len = (data.get("usage") or {}).get(
+                                    "completion_tokens", output_len
+                                )
+
+                        output.generated_text = generated_text
+                        output.success = True
+                        output.latency = latency
+                        output.output_len = output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_truss(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -544,6 +676,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.random_output_len,
+            apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
     else:
@@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
    "sglang-native": async_request_sglang_generate,
     "sglang-oai": async_request_openai_completions,
+    "sglang-oai-chat": async_request_openai_chat_completions,
     "vllm": async_request_openai_completions,
+    "vllm-chat": async_request_openai_chat_completions,
     "lmdeploy": async_request_openai_completions,
+    "lmdeploy-chat": async_request_openai_chat_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
     "truss": async_request_truss,
@@ -661,6 +797,7 @@ def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    apply_chat_template: bool = True,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
     """
@@ -670,6 +807,7 @@ def sample_mmmu_requests(
         num_requests: Number of requests to sample.
         tokenizer: Tokenizer to use for token counting.
         fixed_output_len: If provided, use this fixed output length for all requests.
+        apply_chat_template: Whether to apply the chat template to the prompt.
         random_sample: Whether to randomly sample or take the first N.
 
     Returns:
@@ -739,28 +877,30 @@ def sample_mmmu_requests(
 
         # Construct the prompt
         prompt = f"Question: {question}\n\nAnswer: "
-
-        try:
-            prompt = tokenizer.apply_chat_template(
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": image_data},
-                            },
-                            {"type": "text", "text": prompt},
-                        ],
-                    }
-                ],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-        except Exception as e:
-            # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
-            print(f"Error applying chat template: {e}, fallback to <image> tag")
-            prompt = f"<image>{prompt}"
+        if apply_chat_template:
+            try:
+                prompt = tokenizer.apply_chat_template(
+                    [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": image_data},
+                                },
+                                {"type": "text", "text": prompt},
+                            ],
+                        }
+                    ],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+            except Exception as e:
+                # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
+                print(
+                    f"Error applying chat template: {e}, fallback to <image> tag"
+                )
+                prompt = f"<image>{prompt}"
 
         # Calculate token lengths for text only (without image data)
         prompt_token_ids = tokenizer.encode(prompt)
@@ -1544,6 +1684,12 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/v1/completions"
         )
+    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
+        api_url = (
+            f"{args.base_url}/v1/chat/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/chat/completions"
+        )
     elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
sglang/srt/configs/internvl.py CHANGED
@@ -147,12 +147,14 @@ class InternLM2Config(PretrainedConfig):
         )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
+        if isinstance(rope_scaling_factor, int):
+            rope_scaling_factor = float(rope_scaling_factor)
 
 
 class InternVisionConfig(PretrainedConfig):
sglang/srt/configs/janus_pro.py CHANGED
@@ -19,7 +19,7 @@ from transformers (
 from transformers.image_utils import to_numpy_array
 
 from sglang.srt.configs.utils import register_image_processor, register_processor
-from sglang.srt.mm_utils import expand2square
+from sglang.srt.multimodal.mm_utils import expand2square
 
 
 class DictToObject(dict):
sglang/srt/configs/model_config.py CHANGED
@@ -59,6 +59,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
+        hybrid_kvcache_ratio: Optional[float] = None,
         impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
 
@@ -86,6 +87,18 @@
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
+        self.is_hybrid = is_hybrid_model(
+            self.hf_config.architectures,
+            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
+            context_length=context_length,
+            attention_chunk_size=self.attention_chunk_size,
+        )
+        if self.is_hybrid is not None:
+            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
+                get_hybrid_layer_ids(
+                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
+                )
+            )
 
         if enable_multimodal is None:
             mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             impl=server_args.impl,
             **kwargs,
         )
@@ -579,6 +593,7 @@ multimodal_model_archs = [
     "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
+    "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
@@ -633,3 +648,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def is_hybrid_model(
+    model_architectures: List[str],
+    hybrid_kvcache_ratio: Optional[float],
+    context_length: Optional[int],
+    attention_chunk_size: Optional[int],
+):
+    if hybrid_kvcache_ratio is None:
+        return None
+    elif (
+        hybrid_kvcache_ratio > 0
+        and model_architectures[0] == "Llama4ForConditionalGeneration"
+        and context_length > attention_chunk_size
+    ):
+        return hybrid_kvcache_ratio
+    else:
+        return None
+
+
+def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
+    if "Llama4ForConditionalGeneration" in model_architectures:
+        swa_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
+        ]
+        full_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
+        ]
+    else:
+        raise ValueError(
+            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
+        )
+    return swa_attention_layer_ids, full_attention_layer_ids
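For orientation, the hybrid KV-cache split added above gives every fourth layer a full-attention KV cache and the remaining layers sliding-window attention. A tiny standalone check of that rule (the 8-layer count is an arbitrary example, not taken from the diff):

# Mirrors the split in get_hybrid_layer_ids above; 8 layers chosen only for illustration.
num_hidden_layers = 8
swa_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
full_layers = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
print(swa_layers)   # [0, 1, 2, 4, 5, 6] -> sliding-window attention layers
print(full_layers)  # [3, 7]             -> full-attention layers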
sglang/srt/configs/update_config.py ADDED
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+DEFAULT_MOE_PADDING_SIZE = 32
+
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
+
+def may_get_weight_block_size(model_config, load_config):
+    from sglang.srt.model_loader.loader import _get_quantization_config
+    from sglang.srt.model_loader.utils import get_model_architecture
+
+    model_class, _ = get_model_architecture(model_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
+
+    if quant_config is not None and hasattr(quant_config, "weight_block_size"):
+        return getattr(quant_config, "weight_block_size")
+    return None
+
+
+def get_moe_padding_size(weight_block_size):
+    if weight_block_size is not None:
+        # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+        assert (
+            len(weight_block_size) == 2
+        ), "Only len(weight_block_size) == 2 is supported"
+        assert (
+            weight_block_size[0] == weight_block_size[1]
+        ), "Only weight_block_size[0] == weight_block_size[1] is supported"
+
+        return weight_block_size[0]
+
+    return DEFAULT_MOE_PADDING_SIZE
+
+
+def get_num_heads_padding_size(tp_size, weight_block_size):
+    pad_size = (
+        tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
+    )
+    return pad_size
+
+
+def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
+    if hasattr(model_config.hf_config, attr_name):
+        attr_value = getattr(model_config.hf_config, attr_name)
+        if attr_value % intermediate_padding_size != 0:
+            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+            attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
+            setattr(model_config.hf_config, attr_name, attr_value)
+            setattr(model_config.hf_text_config, attr_name, attr_value)
+    return model_config
+
+
+def adjust_config_with_unaligned_cpu_tp(
+    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
+) -> ModelConfig:
+    # Support the case where the num_attention_heads is not divisible by the TP size.
+    weight_block_size = may_get_weight_block_size(model_config, load_config)
+
+    model_config.hf_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+    model_config.hf_text_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+
+    model_config.hf_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+    model_config.hf_text_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+
+    if (
+        model_config.num_attention_heads % tp_size != 0
+        or model_config.get_total_num_kv_heads() % tp_size != 0
+    ):
+        # Compute the head_dim using the model_config.num_attention_heads before padding
+        if not hasattr(model_config.hf_config, "head_dim"):
+            model_config.hf_config.head_dim = (
+                model_config.hidden_size // model_config.num_attention_heads
+            )
+
+        query_heads_per_kv = (
+            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
+        )
+        total_kv_heads = model_config.get_total_num_kv_heads()
+        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
+        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
+
+        model_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_text_config.num_key_value_heads = num_key_value_heads
+
+        num_attention_heads = num_key_value_heads * query_heads_per_kv
+        model_config.num_attention_heads = num_attention_heads
+        model_config.hf_config.num_attention_heads = num_attention_heads
+        model_config.hf_text_config.num_attention_heads = num_attention_heads
+
+        intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
+        model_config = update_intermediate_size(
+            model_config, "moe_intermediate_size", intermediate_padding_size
+        )
+        model_config = update_intermediate_size(
+            model_config, "intermediate_size", intermediate_padding_size
+        )
+
+    return model_config
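To see what the padding helpers in this new module do, here is a small illustrative snippet. It assumes sglang 0.4.9 is installed so the module is importable, uses a local round_up helper as a stand-in for pad_vocab_size, and picks the numeric inputs purely as examples:

from sglang.srt.configs.update_config import (
    get_moe_padding_size,
    get_num_heads_padding_size,
)


def round_up(x: int, multiple: int) -> int:
    # Stand-in for pad_vocab_size: round x up to the nearest multiple.
    return ((x + multiple - 1) // multiple) * multiple


# Block-quantized weights pad MoE intermediate sizes to the block size, else 32.
assert get_moe_padding_size([128, 128]) == 128
assert get_moe_padding_size(None) == 32  # DEFAULT_MOE_PADDING_SIZE

# An odd tp_size combined with block quantization doubles the head padding unit.
assert get_num_heads_padding_size(3, [128, 128]) == 6
assert get_num_heads_padding_size(3, None) == 3
assert get_num_heads_padding_size(6, [128, 128]) == 6

# e.g. 40 total KV heads padded for tp_size=3 with block quantization: 40 -> 42.
assert round_up(40, get_num_heads_padding_size(3, [128, 128])) == 42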
sglang/srt/conversation.py CHANGED
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()
 
@@ -350,6 +351,23 @@ class Conversation:
             else:
                 ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -904,6 +922,20 @@ register_conv_template(
 )
 
 
+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -956,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
sglang/srt/disaggregation/decode.py CHANGED
@@ -416,6 +416,12 @@ class DecodePreallocQueue:
 
         return preallocated_reqs
 
+    @property
+    def num_tokens_pre_allocated(self):
+        return sum(
+            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
+        )
+
     def _allocatable_tokens(
         self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
     ) -> int:
@@ -433,9 +439,7 @@
             else 0
         )
 
-        available_size = self.token_to_kv_pool_allocator.available_size()
-
-        allocatable_tokens = available_size - max(
+        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
            # preserve some space for future decode
            self.num_reserved_decode_tokens
            * (
@@ -606,9 +610,21 @@ class DecodeTransferQueue:
                        : decode_req.req.top_logprobs_num
                    ].tolist()
                )
+
                if hasattr(decode_req.kv_receiver, "clear"):
                    decode_req.kv_receiver.clear()
-                transferred_reqs.append(decode_req.req)
+
+                # special handling for sampling_params.max_new_tokens == 1
+                if decode_req.req.sampling_params.max_new_tokens == 1:
+                    # finish immediately
+                    decode_req.req.check_finished()
+                    self.scheduler.stream_output(
+                        [decode_req.req], decode_req.req.return_logprob
+                    )
+                    self.tree_cache.cache_finished_req(decode_req.req)
+                else:
+                    transferred_reqs.append(decode_req.req)
+
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
@@ -756,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.last_batch_in_queue = last_batch_in_queue
 
     def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
-        batch, _ = self.prepare_mlp_sync_batch(batch)
+        batch = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)