sglang 0.4.6__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (44)
  1. sglang/srt/disaggregation/decode.py +8 -2
  2. sglang/srt/disaggregation/fake/__init__.py +1 -0
  3. sglang/srt/disaggregation/fake/conn.py +88 -0
  4. sglang/srt/disaggregation/prefill.py +12 -3
  5. sglang/srt/disaggregation/utils.py +16 -2
  6. sglang/srt/entrypoints/engine.py +9 -0
  7. sglang/srt/entrypoints/http_server.py +27 -2
  8. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  9. sglang/srt/layers/attention/utils.py +1 -1
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  24. sglang/srt/layers/quantization/fp8.py +20 -22
  25. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  26. sglang/srt/managers/schedule_batch.py +9 -0
  27. sglang/srt/managers/scheduler.py +10 -8
  28. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  29. sglang/srt/managers/tp_worker.py +3 -3
  30. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  31. sglang/srt/model_executor/model_runner.py +8 -1
  32. sglang/srt/openai_api/adapter.py +32 -3
  33. sglang/srt/openai_api/protocol.py +2 -0
  34. sglang/srt/reasoning_parser.py +25 -1
  35. sglang/srt/server_args.py +16 -2
  36. sglang/srt/utils.py +3 -0
  37. sglang/test/send_one.py +84 -28
  38. sglang/test/test_utils.py +38 -0
  39. sglang/version.py +1 -1
  40. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +2 -2
  41. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +44 -29
  42. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +0 -0
  43. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  44. {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py CHANGED

@@ -72,8 +72,8 @@ _is_hip = is_hip()
  _is_cuda = is_cuda()

  if _is_hip:
- from aiter import ActivationType
- from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+ from aiter import ActivationType, QuantType
+ from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
  from aiter.ops.shuffle import shuffle_weight

  if not _is_cuda:

@@ -484,7 +484,7 @@ class Fp8MoEMethod:
  if self.quant_config.is_checkpoint_fp8_serialized:
  params_dtype = (
  torch.uint32
- if get_bool_env_var("USE_INT4_WEIGHT")
+ if get_bool_env_var("SGLANG_INT4_WEIGHT")
  else torch.float8_e4m3fn
  )
  tp_size = get_tensor_model_parallel_world_size()

@@ -511,7 +511,7 @@ class Fp8MoEMethod:
  )

  # WEIGHTS
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  # INT4 MoE weight - INT32 packed
  w13_weight = torch.nn.Parameter(
  torch.empty(

@@ -585,7 +585,7 @@ class Fp8MoEMethod:

  if (
  _is_hip
- ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+ ): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
  # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
  w13_weight_scale1 = torch.nn.Parameter(
  torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +612,7 @@ class Fp8MoEMethod:
  set_weight_attrs(w13_weight_scale, extra_weight_attrs)
  set_weight_attrs(w2_weight_scale, extra_weight_attrs)

- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  extra_weight_attrs.update(
  {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
  )

@@ -644,7 +644,7 @@ class Fp8MoEMethod:
  layer.w2_input_scale = None

  def process_weights_after_loading(self, layer: Module) -> None:
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  self.process_weights_hip_int4(layer)
  return

@@ -675,7 +675,7 @@ class Fp8MoEMethod:
  )
  layer.w2_input_scale = None

- if get_bool_env_var("CK_MOE"):
+ if get_bool_env_var("SGLANG_AITER_MOE"):
  # Pre-shuffle weights
  layer.w13_weight.data = shuffle_weight(
  layer.w13_weight.contiguous(), (16, 16)

@@ -798,17 +798,15 @@ class Fp8MoEMethod:
  return

  def process_weights_hip_int4(self, layer: Module):
- # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+ # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
  # Weight Permutation
  layer.w13_weight = torch.nn.Parameter(
- # permute_weight(layer.w13_weight.data),
  shuffle_weight(layer.w13_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
  layer.w2_weight = torch.nn.Parameter(
- # permute_weight(layer.w2_weight.data),
  shuffle_weight(layer.w2_weight.data, (16, 16)),
  requires_grad=False,
  )

@@ -847,23 +845,21 @@ class Fp8MoEMethod:
  padding_size, # Avoid circular import
  )

- if get_bool_env_var("CK_MOE"):
+ if get_bool_env_var("SGLANG_AITER_MOE"):
  layer.w13_weight = torch.nn.Parameter(
- # permute_weight(layer.w13_weight.data),
  shuffle_weight(layer.w13_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
  layer.w2_weight = torch.nn.Parameter(
- # permute_weight(layer.w2_weight.data),
  shuffle_weight(layer.w2_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
- # ROCm (CK_MOE): using column-wise scaling
+ # ROCm (SGLANG_AITER_MOE): using column-wise scaling
  layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
  layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
- elif get_bool_env_var("MOE_PADDING"):
+ elif get_bool_env_var("SGLANG_MOE_PADDING"):
  # If ROCm, apply weight padding (min. Mem channel contention) only if set
  layer.w13_weight = torch.nn.Parameter(
  F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),

@@ -912,15 +908,16 @@ class Fp8MoEMethod:
  )

  if _is_hip:
- if get_bool_env_var("USE_INT4_WEIGHT"):
- # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+ if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+ # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
  assert not no_combine, f"{no_combine=} is not supported."
- return ck_moe_2stages_win4(
+ return ck_moe_2stages(
  x,
  layer.w13_weight,
  layer.w2_weight,
  topk_weights,
  topk_ids,
+ QuantType.per_Token,
  layer.w13_weight_scale1,
  layer.w2_weight_scale1,
  activation=(

@@ -930,13 +927,13 @@ class Fp8MoEMethod:
  ),
  )

- if get_bool_env_var("CK_MOE"):
+ if get_bool_env_var("SGLANG_AITER_MOE"):
  assert not no_combine, f"{no_combine=} is not supported."
  if self.block_quant:
- # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+ # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
  assert (
  activation == "silu"
- ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+ ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
  return asm_moe(
  x,
  layer.w13_weight,

@@ -955,6 +952,7 @@ class Fp8MoEMethod:
  layer.w2_weight,
  topk_weights,
  topk_ids,
+ QuantType.per_Token,
  layer.w13_weight_scale1,
  layer.w2_weight_scale1,
  activation=(
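Note: the hunks above rename the ROCm MoE environment flags (CK_MOE to SGLANG_AITER_MOE, USE_INT4_WEIGHT to SGLANG_INT4_WEIGHT, MOE_PADDING to SGLANG_MOE_PADDING) and switch the INT4 path from ck_moe_2stages_win4 to ck_moe_2stages with an explicit QuantType.per_Token argument. As a rough sketch of how such boolean flags are typically read (an illustrative helper, not sglang's get_bool_env_var):

import os

def read_bool_flag(name: str, default: str = "false") -> bool:
    # Accept the usual truthy spellings, case-insensitively.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# Gate the ROCm aiter MoE paths on the renamed flags.
if read_bool_flag("SGLANG_AITER_MOE") and read_bool_flag("SGLANG_INT4_WEIGHT"):
    print("INT4-FP8 aiter MoE path enabled")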
sglang/srt/layers/quantization/fp8_utils.py CHANGED

@@ -31,7 +31,7 @@ from sglang.srt.utils import (
  _is_hip = is_hip()
  _is_cuda = is_cuda()

- if _is_hip and get_bool_env_var("CK_MOE"):
+ if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  from aiter import gemm_a8w8_blockscale

  if _is_cuda:

@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
  output = fp8_blockwise_scaled_mm(
  q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
  )
- elif _is_hip and get_bool_env_var("CK_MOE"):
+ elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  q_input, x_scale = per_token_group_quant_fp8(
  input_2d, block_size[1], column_major_scales=False
  )
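For context, the SGLANG_AITER_MOE branch of apply_w8a8_block_fp8_linear quantizes the activations per token group (group width block_size[1]) before the block-scaled GEMM. A rough reference of what per-token-group FP8 quantization computes, assuming a PyTorch build with float8 dtypes; this is a plain illustration, not the kernel sglang uses:

import torch

FP8_E4M3_MAX = 448.0  # max magnitude representable by float8_e4m3fn

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int):
    # One scale per (token, column group); the scale maps the group's max |x| to FP8 max.
    n_tokens, hidden = x.shape
    assert hidden % group_size == 0
    grouped = x.float().view(n_tokens, hidden // group_size, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / FP8_E4M3_MAX
    q = (grouped / scales).to(torch.float8_e4m3fn)
    return q.view(n_tokens, hidden), scales.squeeze(-1)

q, s = per_token_group_quant_fp8_ref(torch.randn(4, 256), group_size=128)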
sglang/srt/managers/schedule_batch.py CHANGED

@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
  import copy
  import dataclasses
  import logging
+ import threading
  from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

  import numpy as np

@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # This is an optimization to reduce the overhead of the prefill check.
  batch_is_full: bool = False

+ # Events
+ launch_done: Optional[threading.Event] = None
+
  # Sampling info
  sampling_info: SamplingBatchInfo = None
  next_batch_sampling_info: SamplingBatchInfo = None

@@ -1511,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )
  or global_server_args_dict["attention_backend"] == "flashmla"
  or global_server_args_dict["attention_backend"] == "fa3"
+ or global_server_args_dict["attention_backend"] == "cutlass_mla"
  ):
  seq_lens_cpu = self.seq_lens.cpu()
  else:

@@ -1565,6 +1570,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )
  ),
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+ launch_done=self.launch_done,
  )

  def copy(self):

@@ -1647,6 +1653,9 @@ class ModelWorkerBatch:
  # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None

+ # Overlap event
+ launch_done: Optional[threading.Event] = None
+

  @triton.jit
  def write_req_to_token_pool_triton(

sglang/srt/managers/scheduler.py CHANGED

@@ -248,9 +248,6 @@ class Scheduler(
  if not self.is_generation:
  self.enable_overlap = False
  logger.info("Overlap scheduler is disabled for embedding models.")
- if self.model_config.is_multimodal:
- self.enable_overlap = False
- logger.info("Overlap scheduler is disabled for multimodal models.")

  # Launch a tensor parallel worker
  if self.enable_overlap:

@@ -645,6 +642,7 @@ class Scheduler(
  self.cur_batch = batch

  if batch:
+ batch.launch_done = threading.Event()
  result = self.run_batch(batch)
  self.result_queue.append((batch.copy(), result))

@@ -656,7 +654,7 @@ class Scheduler(
  forward_mode=ForwardMode.DUMMY_FIRST,
  next_batch_sampling_info=self.tp_worker.cur_sampling_info,
  )
- self.process_batch_result(tmp_batch, None)
+ self.process_batch_result(tmp_batch, None, batch.launch_done)

  if self.last_batch:
  # Process the results of the last batch

@@ -664,7 +662,10 @@ class Scheduler(
  tmp_batch.next_batch_sampling_info = (
  self.tp_worker.cur_sampling_info if batch else None
  )
- self.process_batch_result(tmp_batch, tmp_result)
+ # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+ self.process_batch_result(
+ tmp_batch, tmp_result, batch.launch_done if batch else None
+ )
  elif batch is None:
  # When the server is idle, do self-check and re-init some states
  self.check_memory()

@@ -1417,14 +1418,15 @@ class Scheduler(
  self,
  batch: ScheduleBatch,
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ launch_done: Optional[threading.Event] = None,
  ):
  if batch.forward_mode.is_decode():
- self.process_batch_result_decode(batch, result)
+ self.process_batch_result_decode(batch, result, launch_done)
  elif batch.forward_mode.is_extend():
- self.process_batch_result_prefill(batch, result)
+ self.process_batch_result_prefill(batch, result, launch_done)
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
- self.tp_worker.resolve_batch_result(result.bid)
+ self.tp_worker.resolve_last_batch_result(launch_done)
  if batch.next_batch_sampling_info:
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
sglang/srt/managers/scheduler_output_processor_mixin.py CHANGED

@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import threading
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -11,6 +12,7 @@ if TYPE_CHECKING:
  EmbeddingBatchResult,
  GenerationBatchResult,
  ScheduleBatch,
+ Scheduler,
  )

@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
  """

  def process_batch_result_prefill(
- self,
+ self: Scheduler,
  batch: ScheduleBatch,
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ launch_done: Optional[threading.Event] = None,
  ):
  skip_stream_req = None

@@ -43,7 +46,11 @@
  )

  if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+ logits_output, next_token_ids = (
+ self.tp_worker.resolve_last_batch_result(
+ launch_done,
+ )
+ )
  else:
  # Move next_token_ids and logprobs to cpu
  next_token_ids = next_token_ids.tolist()

@@ -175,9 +182,10 @@
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

  def process_batch_result_decode(
- self,
+ self: Scheduler,
  batch: ScheduleBatch,
  result: GenerationBatchResult,
+ launch_done: Optional[threading.Event] = None,
  ):
  logits_output, next_token_ids, bid = (
  result.logits_output,

@@ -187,7 +195,9 @@
  self.num_generated_tokens += len(batch.reqs)

  if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+ logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+ launch_done
+ )
  next_token_logprobs = logits_output.next_token_logprobs
  elif batch.spec_algorithm.is_none():
  # spec decoding handles output logprobs inside verify process.

@@ -271,7 +281,7 @@
  self.log_decode_stats()

  def add_input_logprob_return_values(
- self,
+ self: Scheduler,
  i: int,
  req: Req,
  output: LogitsProcessorOutput,

@@ -405,7 +415,7 @@
  assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

  def add_logprob_return_values(
- self,
+ self: Scheduler,
  i: int,
  req: Req,
  pt: int,

@@ -436,7 +446,10 @@
  return num_input_logprobs

  def stream_output(
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+ self: Scheduler,
+ reqs: List[Req],
+ return_logprob: bool,
+ skip_req: Optional[Req] = None,
  ):
  """Stream the output to detokenizer."""
  if self.is_generation:

@@ -445,7 +458,10 @@
  self.stream_output_embedding(reqs)

  def stream_output_generation(
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+ self: Scheduler,
+ reqs: List[Req],
+ return_logprob: bool,
+ skip_req: Optional[Req] = None,
  ):
  rids = []
  finished_reasons: List[BaseFinishReason] = []
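The signature changes in this file annotate self as Scheduler (imported only under TYPE_CHECKING) so that type checkers can resolve attributes such as self.tp_worker and self.enable_overlap, which exist on the concrete Scheduler rather than on the mixin. A condensed illustration of the pattern, using hypothetical names outside of sglang:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checking; avoids a circular import at runtime.
    from scheduler import Scheduler  # hypothetical module


class OutputProcessorMixin:
    def process_result(self: Scheduler, result: object) -> None:
        # `self` is typed as the concrete Scheduler, so mixin code that uses
        # Scheduler-only attributes still type-checks.
        if self.enable_overlap:
            print("overlap enabled", result)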
@@ -593,7 +609,7 @@
  )
  )

- def stream_output_embedding(self, reqs: List[Req]):
+ def stream_output_embedding(self: Scheduler, reqs: List[Req]):
  rids = []
  finished_reasons: List[BaseFinishReason] = []
sglang/srt/managers/tp_worker.py CHANGED

@@ -170,13 +170,13 @@ class TpModelWorker:
  def forward_batch_generation(
  self,
  model_worker_batch: ModelWorkerBatch,
- launch_done: Optional[threading.Event] = None,
  skip_sample: bool = False,
  ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
  logits_output = self.model_runner.forward(forward_batch)
- if launch_done:
- launch_done.set()
+
+ if model_worker_batch.launch_done is not None:
+ model_worker_batch.launch_done.set()

  if skip_sample:
  next_token_ids = None

sglang/srt/managers/tp_worker_overlap_thread.py CHANGED

@@ -132,7 +132,6 @@ class TpModelWorkerClient:
  batch_pt += 1

  # Create event
- self.launch_done = threading.Event()
  copy_done = torch.get_device_module(self.device).Event()

  # Resolve future tokens in the input

@@ -141,7 +140,7 @@

  # Run forward
  logits_output, next_token_ids = self.worker.forward_batch_generation(
- model_worker_batch, self.launch_done
+ model_worker_batch
  )

  # Update the future token ids map

@@ -168,10 +167,16 @@

  self.output_queue.put((copy_done, logits_output, next_token_ids))

- def resolve_batch_result(self, bid: int):
+ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+ """
+ This function is called to resolve the last batch result and
+ wait for the current batch to be launched. Used in overlap mode.
+ """
  copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+ if launch_done is not None:
+ launch_done.wait()
  copy_done.synchronize()
- self.launch_done.wait()

  if logits_output.next_token_logprobs is not None:
  logits_output.next_token_logprobs = (
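Taken together, these changes replace the single self.launch_done event on TpModelWorkerClient with the per-batch launch_done carried on ScheduleBatch/ModelWorkerBatch: the forward thread sets a batch's event once its forward pass has been launched, and resolve_last_batch_result waits on the event of the currently launched batch before returning the previous batch's result. A stripped-down sketch of that handshake (class and method names here are illustrative, not the sglang API):

import queue
import threading
from typing import Optional


class OverlapWorkerSketch:
    def __init__(self) -> None:
        self.output_queue: queue.Queue = queue.Queue()

    def forward_batch(self, batch_id: int, launch_done: threading.Event) -> None:
        # Stand-in for enqueueing the GPU forward pass.
        result = batch_id * 2
        launch_done.set()          # signal: this batch has been launched
        self.output_queue.put(result)

    def resolve_last_batch_result(self, launch_done: Optional[threading.Event]):
        result = self.output_queue.get()
        if launch_done is not None:
            launch_done.wait()     # wait on the per-batch event added in this release
        return result


worker = OverlapWorkerSketch()
event = threading.Event()
threading.Thread(target=worker.forward_batch, args=(7, event)).start()
print(worker.resolve_last_batch_result(event))  # 14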
sglang/srt/model_executor/model_runner.py CHANGED

@@ -271,6 +271,7 @@ class ModelRunner:
  "fa3",
  "triton",
  "flashmla",
+ "cutlass_mla",
  ]:
  logger.info(
  f"MLA optimization is turned on. Use {server_args.attention_backend} backend."

@@ -926,6 +927,12 @@ class ModelRunner:
  )

  self.attn_backend = FlashAttentionBackend(self)
+ elif self.server_args.attention_backend == "cutlass_mla":
+ from sglang.srt.layers.attention.cutlass_mla_backend import (
+ CutlassMLABackend,
+ )
+
+ self.attn_backend = CutlassMLABackend(self)
  else:
  raise ValueError(
  f"Invalid attention backend: {self.server_args.attention_backend}"

@@ -968,7 +975,7 @@ class ModelRunner:
  after_mem = get_available_gpu_memory(self.device, self.gpu_id)
  logger.info(
  f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
- f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+ f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
  )

  def apply_torch_tp(self):
sglang/srt/openai_api/adapter.py CHANGED

@@ -971,6 +971,8 @@ def v1_chat_generate_request(
  )

  for message in request.messages:
+ if message.content is None:
+ message.content = ""
  if isinstance(message.content, str):
  openai_compatible_messages.append(
  {"role": message.role, "content": message.content}

@@ -1001,6 +1003,11 @@
  tokenize=True,
  add_generation_prompt=True,
  tools=tools,
+ **(
+ request.chat_template_kwargs
+ if request.chat_template_kwargs
+ else {}
+ ),
  )
  except:
  # This except branch will be triggered when the chosen model

@@ -1012,6 +1019,11 @@
  tokenize=True,
  add_generation_prompt=True,
  tools=tools,
+ **(
+ request.chat_template_kwargs
+ if request.chat_template_kwargs
+ else {}
+ ),
  )

  if assistant_prefix:
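The two hunks above forward request.chat_template_kwargs into tokenizer.apply_chat_template, so template-level switches such as Qwen3's enable_thinking reach the chat template. Roughly what that call expands to on the tokenizer side (model name and kwargs are examples, assuming a chat template that understands enable_thinking):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # example model

messages = [{"role": "user", "content": "Explain paged KV cache briefly."}]
chat_template_kwargs = {"enable_thinking": False}  # as sent by the client

prompt_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    **chat_template_kwargs,  # forwarded verbatim, mirroring the diff above
)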
@@ -1179,6 +1191,7 @@
  modalities=modalities_list,
  lora_path=lora_paths,
  bootstrap_host=all_requests[0].bootstrap_host,
+ bootstrap_port=all_requests[0].bootstrap_port,
  bootstrap_room=all_requests[0].bootstrap_room,
  )

@@ -1245,16 +1258,34 @@ def v1_chat_generate_response(
  tool_calls = None
  text = ret_item["text"]

+ enable_thinking = True
  if isinstance(request, list):
  tool_choice = request[idx].tool_choice
  tools = request[idx].tools
  separate_reasoning = request[idx].separate_reasoning
+
+ if (
+ request[idx].chat_template_kwargs
+ and request[idx].chat_template_kwargs.get("enable_thinking") is not None
+ ):
+ enable_thinking = request[idx].chat_template_kwargs.get(
+ "enable_thinking", True
+ )
  else:
  tool_choice = request.tool_choice
  tools = request.tools
  separate_reasoning = request.separate_reasoning

- if reasoning_parser and separate_reasoning:
+ if (
+ request.chat_template_kwargs
+ and request.chat_template_kwargs.get("enable_thinking") is not None
+ ):
+ enable_thinking = request.chat_template_kwargs.get(
+ "enable_thinking", True
+ )
+
+ reasoning_text = None
+ if reasoning_parser and separate_reasoning and enable_thinking:
  try:
  parser = ReasoningParser(
  model_type=reasoning_parser, stream_reasoning=False

@@ -1266,8 +1297,6 @@ def v1_chat_generate_response(
  HTTPStatus.BAD_REQUEST,
  "Failed to parse reasoning related info to json format!",
  )
- else:
- reasoning_text = None

  if tool_choice != "none" and tools:
  parser = FunctionCallParser(tools, tool_call_parser)
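On the client side, the enable_thinking handling above and the new chat_template_kwargs field (see the ChatCompletionRequest change below) can be exercised through any OpenAI-compatible client, for example via extra_body in the openai Python SDK (endpoint and model are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",  # example model served by sglang
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    extra_body={
        "chat_template_kwargs": {"enable_thinking": True},
        "separate_reasoning": True,
    },
)
print(response.choices[0].message.content)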
sglang/srt/openai_api/protocol.py CHANGED

@@ -361,9 +361,11 @@ class ChatCompletionRequest(BaseModel):
  session_params: Optional[Dict] = None
  separate_reasoning: bool = True
  stream_reasoning: bool = True
+ chat_template_kwargs: Optional[Dict] = None

  # For PD disaggregation
  bootstrap_host: Optional[str] = None
+ bootstrap_port: Optional[int] = None
  bootstrap_room: Optional[int] = None
sglang/srt/reasoning_parser.py CHANGED

@@ -117,6 +117,29 @@ class DeepSeekR1Detector(BaseReasoningFormatDetector):
  # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599


+ class Qwen3Detector(BaseReasoningFormatDetector):
+ """
+ Detector for Qwen3 model.
+ Assumes reasoning format:
+ (<think>)*(.*)</think>
+ Returns all the text before the </think> tag as `reasoning_text`
+ and the rest of the text as `normal_text`.
+
+ Args:
+ stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
+ If True, streams reasoning content as it arrives.
+ """
+
+ def __init__(self, stream_reasoning: bool = True):
+ # Qwen3 is assumed to be reasoning until `</think>` token
+ super().__init__(
+ "<think>",
+ "</think>",
+ force_reasoning=True,
+ stream_reasoning=stream_reasoning,
+ )
+
+
  class ReasoningParser:
  """
  Parser that handles both streaming and non-streaming scenarios for extracting

@@ -129,7 +152,8 @@ class ReasoningParser:
  """

  DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
- "deepseek-r1": DeepSeekR1Detector
+ "deepseek-r1": DeepSeekR1Detector,
+ "qwen3": Qwen3Detector,
  }

  def __init__(self, model_type: str = None, stream_reasoning: bool = True):
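The new Qwen3Detector registered above treats everything before the closing </think> tag as reasoning (force_reasoning=True), even when the opening <think> tag is never emitted. A self-contained sketch of that splitting rule for the non-streaming case (not the sglang detector itself):

def split_qwen3_reasoning(text: str) -> tuple:
    """Return (reasoning_text, normal_text) following (<think>)*(.*)</think>."""
    end_tag = "</think>"
    if end_tag not in text:
        # Still inside the reasoning section; nothing "normal" yet.
        return text.removeprefix("<think>").strip(), ""
    reasoning, normal = text.split(end_tag, 1)
    return reasoning.removeprefix("<think>").strip(), normal.strip()


print(split_qwen3_reasoning("<think>2 + 2 = 4</think>The answer is 4."))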
sglang/srt/server_args.py CHANGED
@@ -256,6 +256,12 @@ class ServerArgs:
  )
  self.page_size = 64

+ if self.attention_backend == "cutlass_mla":
+ logger.warning(
+ "Cutlass MLA only supports a page_size of 128, change page_size to 128."
+ )
+ self.page_size = 128
+
  # Set cuda graph max batch size
  if self.cuda_graph_max_bs is None:
  # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.

@@ -420,7 +426,7 @@ class ServerArgs:
  parser.add_argument(
  "--skip-tokenizer-init",
  action="store_true",
- help="If set, skip init tokenizer and pass input_ids in generate request",
+ help="If set, skip init tokenizer and pass input_ids in generate request.",
  )
  parser.add_argument(
  "--enable-tokenizer-batch-encode",

@@ -559,6 +565,7 @@ class ServerArgs:
  "name, a tag name, or a commit id. If unspecified, will use "
  "the default version.",
  )
+
  # Memory and scheduling
  parser.add_argument(
  "--mem-fraction-static",

@@ -823,7 +830,14 @@ class ServerArgs:
  parser.add_argument(
  "--attention-backend",
  type=str,
- choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
+ choices=[
+ "flashinfer",
+ "triton",
+ "torch_native",
+ "fa3",
+ "flashmla",
+ "cutlass_mla",
+ ],
  default=ServerArgs.attention_backend,
  help="Choose the kernels for attention layers.",
  )
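cutlass_mla is now an accepted --attention-backend choice, and the ServerArgs hunk above forces page_size to 128 when it is selected. A quick way to try it from the offline engine API, assuming an MLA-style model and a GPU supported by the CUTLASS MLA kernels (model path and hardware are illustrative; Engine keyword arguments map onto ServerArgs fields):

import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # example MLA-style model
    attention_backend="cutlass_mla",       # new backend choice; page_size becomes 128
    trust_remote_code=True,
)
out = llm.generate("Hello", {"max_new_tokens": 16})
print(out["text"])
llm.shutdown()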
sglang/srt/utils.py CHANGED
@@ -1970,8 +1970,11 @@ def is_fa3_default_architecture(hf_config):
  "Llama4ForConditionalGeneration",
  "LlamaForCausalLM",
  "MistralForCausalLM",
+ "MixtralForCausalLM",
  "Gemma2ForCausalLM",
  "Gemma3ForConditionalGeneration",
+ "Qwen3ForCausalLM",
+ "Qwen3MoeForCausalLM",
  }
  return architectures[0] in default_archs