sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch_server.py CHANGED
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
                 "json_schema": json_schema,
+                "stream_interval": stream_interval,
             },
             "return_logprob": return_logprob,
             "stream": True,
@@ -245,8 +253,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)
 
-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info").json()
+    if "tokenizer_path" in server_info:
+        tokenizer_path = server_info["tokenizer_path"]
+    elif "prefill" in server_info:
+        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)
 
     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +270,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +293,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
@@ -301,6 +315,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
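Note on the tokenizer change above: instead of deriving the tokenizer from the local ServerArgs, the benchmark now asks the running server for its tokenizer path via /get_server_info, falling back to the first prefill entry for disaggregated prefill/decode deployments. A minimal standalone sketch of that lookup (the helper name and the explicit error branch are illustrative, not part of the diff):

import requests

from sglang.srt.hf_transformers_utils import get_tokenizer


def resolve_tokenizer(base_url: str):
    # Ask the server which tokenizer it is actually serving; PD-disaggregated
    # deployments report the path under their "prefill" entries instead.
    server_info = requests.get(base_url + "/get_server_info").json()
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    else:
        raise KeyError("/get_server_info did not expose a tokenizer_path")
    return get_tokenizer(tokenizer_path)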
sglang/bench_serving.py CHANGED
@@ -265,6 +265,138 @@ async def async_request_openai_completions(
     return output
 
 
+async def async_request_openai_chat_completions(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    """Makes a request to the OpenAI Chat Completions API.
+
+    Handles both streaming and non-streaming responses, including support
+    for image data in messages. Calculates and returns various performance
+    metrics.
+
+    Args:
+        request_func_input: Input parameters for the request.
+        pbar: Optional tqdm progress bar to update.
+
+    Returns:
+        RequestFuncOutput: Output of the request, including generated text,
+        latency, TTFT, ITL, and success status.
+    """
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        "chat/completions"
+    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
+
+    if request_func_input.image_data:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": request_func_input.image_data},
+                    },
+                    {"type": "text", "text": request_func_input.prompt},
+                ],
+            },
+        ]
+    else:
+        messages = [{"role": "user", "content": request_func_input.prompt}]
+
+    async with _create_bench_client_session() as session:
+        payload = {
+            "model": request_func_input.model,
+            "messages": messages,
+            "temperature": 0.0,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = get_auth_headers()
+
+        output = RequestFuncOutput.init_new(request_func_input)
+
+        generated_text = ""
+        output_len = request_func_input.output_len
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    if args.disable_stream:
+                        # Non-streaming response
+                        response_json = await response.json()
+                        output.generated_text = response_json["choices"][0]["message"][
+                            "content"
+                        ]
+                        output.success = True
+                        output.latency = time.perf_counter() - st
+                        output.ttft = (
+                            output.latency
+                        )  # For non-streaming, TTFT = total latency
+                        output.output_len = response_json.get("usage", {}).get(
+                            "completion_tokens", output_len
+                        )
+                    else:
+                        # Streaming response
+                        async for chunk_bytes in response.content:
+                            chunk_bytes = chunk_bytes.strip()
+                            if not chunk_bytes:
+                                continue
+
+                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                            latency = time.perf_counter() - st
+                            if chunk == "[DONE]":
+                                pass
+                            else:
+                                data = json.loads(chunk)
+
+                                # Check if this chunk contains content
+                                delta = data.get("choices", [{}])[0].get("delta", {})
+                                content = delta.get("content", "")
+
+                                if content:
+                                    timestamp = time.perf_counter()
+                                    # First token
+                                    if ttft == 0.0:
+                                        ttft = timestamp - st
+                                        output.ttft = ttft
+
+                                    # Decoding phase
+                                    else:
+                                        output.itl.append(
+                                            timestamp - most_recent_timestamp
+                                        )
+
+                                    most_recent_timestamp = timestamp
+                                    generated_text += content
+
+                                # Check for usage info in final chunk
+                                output_len = (data.get("usage") or {}).get(
+                                    "completion_tokens", output_len
+                                )
+
+                        output.generated_text = generated_text
+                        output.success = True
+                        output.latency = latency
+                        output.output_len = output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_truss(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -544,6 +676,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.random_output_len,
+            apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
     else:
@@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
     "sglang-oai": async_request_openai_completions,
+    "sglang-oai-chat": async_request_openai_chat_completions,
     "vllm": async_request_openai_completions,
+    "vllm-chat": async_request_openai_chat_completions,
     "lmdeploy": async_request_openai_completions,
+    "lmdeploy-chat": async_request_openai_chat_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
     "truss": async_request_truss,
@@ -661,6 +797,7 @@ def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    apply_chat_template: bool = True,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
     """
@@ -670,15 +807,16 @@ def sample_mmmu_requests(
         num_requests: Number of requests to sample.
        tokenizer: Tokenizer to use for token counting.
        fixed_output_len: If provided, use this fixed output length for all requests.
+        apply_chat_template: Whether to apply the chat template to the prompt.
        random_sample: Whether to randomly sample or take the first N.
 
     Returns:
         List of tuples (prompt, prompt_token_len, output_token_len).
     """
     try:
-        import base64
         import io
 
+        import pybase64
         from datasets import load_dataset
     except ImportError:
         raise ImportError("Please install datasets: pip install datasets")
@@ -729,7 +867,7 @@ def sample_mmmu_requests(
                 # Encode image to base64
                 buffered = io.BytesIO()
                 image.save(buffered, format="JPEG")
-                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
                 image_data = f"data:image/jpeg;base64,{img_str}"
             else:
                 continue
@@ -739,28 +877,30 @@ def sample_mmmu_requests(
 
             # Construct the prompt
             prompt = f"Question: {question}\n\nAnswer: "
-
-            try:
-                prompt = tokenizer.apply_chat_template(
-                    [
-                        {
-                            "role": "user",
-                            "content": [
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": image_data},
-                                },
-                                {"type": "text", "text": prompt},
-                            ],
-                        }
-                    ],
-                    add_generation_prompt=True,
-                    tokenize=False,
-                )
-            except Exception as e:
-                # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
-                print(f"Error applying chat template: {e}, fallback to <image> tag")
-                prompt = f"<image>{prompt}"
+            if apply_chat_template:
+                try:
+                    prompt = tokenizer.apply_chat_template(
+                        [
+                            {
+                                "role": "user",
+                                "content": [
+                                    {
+                                        "type": "image_url",
+                                        "image_url": {"url": image_data},
+                                    },
+                                    {"type": "text", "text": prompt},
+                                ],
+                            }
+                        ],
+                        add_generation_prompt=True,
+                        tokenize=False,
+                    )
+                except Exception as e:
+                    # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
+                    print(
+                        f"Error applying chat template: {e}, fallback to <image> tag"
+                    )
+                    prompt = f"<image>{prompt}"
 
             # Calculate token lengths for text only (without image data)
             prompt_token_ids = tokenizer.encode(prompt)
@@ -1544,6 +1684,12 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
            else f"http://{args.host}:{args.port}/v1/completions"
        )
+    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
+        api_url = (
+            f"{args.base_url}/v1/chat/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/chat/completions"
+        )
     elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
sglang/srt/configs/internvl.py CHANGED
@@ -147,12 +147,14 @@ class InternLM2Config(PretrainedConfig):
             )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
+        if isinstance(rope_scaling_factor, int):
+            rope_scaling_factor = float(rope_scaling_factor)
 
 
 class InternVisionConfig(PretrainedConfig):
sglang/srt/configs/janus_pro.py CHANGED
@@ -19,7 +19,7 @@ from transformers import (
 from transformers.image_utils import to_numpy_array
 
 from sglang.srt.configs.utils import register_image_processor, register_processor
-from sglang.srt.mm_utils import expand2square
+from sglang.srt.multimodal.mm_utils import expand2square
 
 
 class DictToObject(dict):
sglang/srt/configs/model_config.py CHANGED
@@ -59,6 +59,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
+        hybrid_kvcache_ratio: Optional[float] = None,
         impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
 
@@ -86,6 +87,18 @@ class ModelConfig:
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
         )
+        self.is_hybrid = is_hybrid_model(
+            self.hf_config.architectures,
+            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
+            context_length=context_length,
+            attention_chunk_size=self.attention_chunk_size,
+        )
+        if self.is_hybrid is not None:
+            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
+                get_hybrid_layer_ids(
+                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
+                )
+            )
 
         if enable_multimodal is None:
             mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             impl=server_args.impl,
             **kwargs,
         )
@@ -345,7 +359,17 @@ class ModelConfig:
                 if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                     quant_cfg = modelopt_quant_config
             elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
-                quant_cfg = modelopt_quant_config
+                quant_config_file = os.path.join(
+                    self.model_path, "hf_quant_config.json"
+                )
+                with open(quant_config_file) as f:
+                    quant_config_dict = json.load(f)
+                json_quant_configs = quant_config_dict["quantization"]
+                quant_algo = json_quant_configs.get("quant_algo", None)
+                if quant_algo == "MIXED_PRECISION":
+                    quant_cfg = {"quant_method": "w4afp8"}
+                else:
+                    quant_cfg = modelopt_quant_config
         return quant_cfg
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -375,6 +399,7 @@ class ModelConfig:
             "w8a8_fp8",
             "moe_wna16",
             "qoq",
+            "w4afp8",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
@@ -579,6 +604,7 @@ multimodal_model_archs = [
     "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
+    "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
@@ -633,3 +659,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def is_hybrid_model(
+    model_architectures: List[str],
+    hybrid_kvcache_ratio: Optional[float],
+    context_length: Optional[int],
+    attention_chunk_size: Optional[int],
+):
+    if hybrid_kvcache_ratio is None:
+        return None
+    elif (
+        hybrid_kvcache_ratio > 0
+        and model_architectures[0] == "Llama4ForConditionalGeneration"
+        and context_length > attention_chunk_size
+    ):
+        return hybrid_kvcache_ratio
+    else:
+        return None
+
+
+def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
+    if "Llama4ForConditionalGeneration" in model_architectures:
+        swa_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
+        ]
+        full_attention_layer_ids = [
+            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
+        ]
+    else:
+        raise ValueError(
+            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
+        )
+    return swa_attention_layer_ids, full_attention_layer_ids
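For reference, the layer split that get_hybrid_layer_ids produces above follows a fixed every-fourth-layer pattern: layers whose 1-based index is a multiple of 4 keep full attention, the rest use sliding-window attention. A small standalone check (12 layers chosen purely for brevity):

num_hidden_layers = 12
swa_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 != 0]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if (i + 1) % 4 == 0]
print(swa_attention_layer_ids)   # [0, 1, 2, 4, 5, 6, 8, 9, 10]
print(full_attention_layer_ids)  # [3, 7, 11]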
sglang/srt/configs/update_config.py ADDED
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+DEFAULT_MOE_PADDING_SIZE = 32
+
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
+
+def may_get_weight_block_size(model_config, load_config):
+    from sglang.srt.model_loader.loader import _get_quantization_config
+    from sglang.srt.model_loader.utils import get_model_architecture
+
+    model_class, _ = get_model_architecture(model_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
+
+    if quant_config is not None and hasattr(quant_config, "weight_block_size"):
+        return getattr(quant_config, "weight_block_size")
+    return None
+
+
+def get_moe_padding_size(weight_block_size):
+    if weight_block_size is not None:
+        # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+        assert (
+            len(weight_block_size) == 2
+        ), "Only len(weight_block_size) == 2 is supported"
+        assert (
+            weight_block_size[0] == weight_block_size[1]
+        ), "Only weight_block_size[0] == weight_block_size[1] is supported"
+
+        return weight_block_size[0]
+
+    return DEFAULT_MOE_PADDING_SIZE
+
+
+def get_num_heads_padding_size(tp_size, weight_block_size):
+    pad_size = (
+        tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
+    )
+    return pad_size
+
+
+def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
+    if hasattr(model_config.hf_config, attr_name):
+        attr_value = getattr(model_config.hf_config, attr_name)
+        if attr_value % intermediate_padding_size != 0:
+            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+            attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
+            setattr(model_config.hf_config, attr_name, attr_value)
+            setattr(model_config.hf_text_config, attr_name, attr_value)
+    return model_config
+
+
+def adjust_config_with_unaligned_cpu_tp(
+    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
+) -> ModelConfig:
+    # Support the case where the num_attention_heads is not divisible by the TP size.
+    weight_block_size = may_get_weight_block_size(model_config, load_config)
+
+    model_config.hf_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+    model_config.hf_text_config.original_num_attention_heads = (
+        model_config.num_attention_heads
+    )
+
+    model_config.hf_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+    model_config.hf_text_config.original_total_num_kv_heads = (
+        model_config.get_total_num_kv_heads()
+    )
+
+    if (
+        model_config.num_attention_heads % tp_size != 0
+        or model_config.get_total_num_kv_heads() % tp_size != 0
+    ):
+        # Compute the head_dim using the model_config.num_attention_heads before padding
+        if not hasattr(model_config.hf_config, "head_dim"):
+            model_config.hf_config.head_dim = (
+                model_config.hidden_size // model_config.num_attention_heads
+            )
+
+        query_heads_per_kv = (
+            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
+        )
+        total_kv_heads = model_config.get_total_num_kv_heads()
+        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
+        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
+
+        model_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_config.num_key_value_heads = num_key_value_heads
+        model_config.hf_text_config.num_key_value_heads = num_key_value_heads
+
+        num_attention_heads = num_key_value_heads * query_heads_per_kv
+        model_config.num_attention_heads = num_attention_heads
+        model_config.hf_config.num_attention_heads = num_attention_heads
+        model_config.hf_text_config.num_attention_heads = num_attention_heads
+
+        intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
+        model_config = update_intermediate_size(
+            model_config, "moe_intermediate_size", intermediate_padding_size
+        )
+        model_config = update_intermediate_size(
+            model_config, "intermediate_size", intermediate_padding_size
+        )
+
+    return model_config
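The padding rules in this new module compose as follows; a worked example under assumed inputs (40 query heads, 8 KV heads, a [128, 128] weight block size, and tp_size=3 — numbers chosen for illustration, not taken from any particular model):

def pad_up(value, multiple):
    # Mirrors the round-up-to-a-multiple behavior that pad_vocab_size provides.
    return (value + multiple - 1) // multiple * multiple


tp_size, weight_block_size = 3, [128, 128]
num_attention_heads, num_kv_heads = 40, 8

# get_num_heads_padding_size: an odd tp_size with block-quantized weights doubles the pad unit.
pad_size = tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
padded_kv_heads = pad_up(num_kv_heads, pad_size)  # 8 -> 12, now divisible by tp_size
padded_attention_heads = padded_kv_heads * (num_attention_heads // num_kv_heads)  # 12 * 5 = 60

# get_moe_padding_size: intermediate sizes are padded to a multiple of tp_size * block_n.
intermediate_padding_size = tp_size * weight_block_size[0]  # 3 * 128 = 384

print(pad_size, padded_kv_heads, padded_attention_heads, intermediate_padding_size)  # 6 12 60 384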