sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -24,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -35,9 +38,13 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -518,10 +525,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
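The local ceil_div helper removed above is now imported from the new sglang.math_utils module (sglang/math_utils.py, +8 lines in the file list). A minimal sketch of the helper, assuming it keeps the integer ceiling-division semantics of the removed code:

    def ceil_div(a: int, b: int) -> int:
        # Smallest integer >= a / b for positive b, without going through floats.
        return (a + b - 1) // b

    assert ceil_div(7, 4) == 2   # 7 / 4 rounds up to 2
    assert ceil_div(8, 4) == 2
    assert ceil_div(9, 4) == 3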
@@ -32,6 +32,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -204,7 +205,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights, dtype=torch.float32
             )  # topk_weights must be FP32 (float32)
 
-            return ck_moe_2stages(
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -241,7 +242,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
-    forward_native = forward_cuda
+    forward_native = forward_cpu
 
 
 class FusedMoE(torch.nn.Module):
@@ -310,6 +315,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -320,9 +327,40 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
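When enable_ep_moe is set, the constructor above repurposes the TP group as an EP group and builds expert_map: a length-num_experts tensor holding -1 for experts owned by other ranks and a local index (0 .. local_num_experts-1) for experts owned by this rank. A standalone sketch of that mapping (the function name here is illustrative, not part of the FusedMoE API):

    import torch

    def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
        # -1 marks experts that live on other EP ranks; local experts get 0..local-1.
        assert num_experts % ep_size == 0
        local = num_experts // ep_size
        expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
        expert_map[ep_rank * local : (ep_rank + 1) * local] = torch.arange(
            local, dtype=torch.int32
        )
        return expert_map

    # 8 experts over 4 EP ranks: rank 1 owns global experts 2 and 3.
    print(build_expert_map(8, 4, 1).tolist())  # [-1, -1, 0, 1, -1, -1, -1, -1]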
@@ -340,7 +378,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +385,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
             layer=self,
-            num_experts=num_experts,
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +485,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
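The rewritten w13 branch above collapses the w1/w3 split into a single narrow call: gate_proj (w1) normally fills the first half of the fused w13 parameter and up_proj (w3) the second half, and a quant method can flip that layout by exposing load_up_proj_weight_first (the TRT-LLM cutlass case mentioned in the comment). A small sketch of the offset selection, with an illustrative helper name:

    def w13_start_offset(shard_id: str, shard_size: int, up_proj_first: bool) -> int:
        # Mirrors the switch_w13 logic above: decide which half of w13 this shard fills.
        assert shard_id in ("w1", "w3")
        second_half = (shard_id == "w1") if up_proj_first else (shard_id == "w3")
        return shard_size if second_half else 0

    assert w13_start_offset("w1", 128, up_proj_first=False) == 0     # gate_proj first
    assert w13_start_offset("w3", 128, up_proj_first=False) == 128   # up_proj second
    assert w13_start_offset("w3", 128, up_proj_first=True) == 0      # layout flipped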
@@ -505,6 +547,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -513,6 +560,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +591,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +598,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +739,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
 
-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -249,7 +297,16 @@ def _mask_topk_ids_padded_region(
     topk_ids[indices >= num_token_non_padded, :] = -1
 
 
-def biased_grouped_topk(
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _biased_grouped_topk_postprocess(
+    topk_ids, expert_location_dispatch_info, num_token_non_padded
+):
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_ids
+
+
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -282,14 +339,13 @@ def biased_grouped_topk(
             num_fused_shared_experts,
             routed_scaling_factor,
         )
-        # TODO merge into kernel for this branch
-        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
-        # TODO will fuse this into kernel, thus use slow manual operation now
-        if num_token_non_padded is None:
-            return topk_weights, topk_ids
-        torch.compile(
-            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
-        )(topk_ids, num_token_non_padded)
+        # TODO merge into kernel
+        if (expert_location_dispatch_info is not None) or (
+            num_token_non_padded is not None
+        ):
+            topk_ids = _biased_grouped_topk_postprocess(
+                topk_ids, expert_location_dispatch_info, num_token_non_padded
+            )
         return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
@@ -314,6 +370,45 @@ def biased_grouped_topk(
         )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
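The renames above (grouped_topk → grouped_topk_gpu, fused_topk_native → fused_topk_torch_native, biased_grouped_topk → biased_grouped_topk_gpu) exist so that the public names can be bound once, at import time, to either the CPU AMX kernels or the GPU/torch implementations. A self-contained sketch of that dispatch pattern, with placeholder functions standing in for the sgl_kernel ops:

    import torch

    def topk_softmax_torch(gating_output: torch.Tensor, topk: int):
        weights = torch.softmax(gating_output.float(), dim=-1)
        return torch.topk(weights, topk, dim=-1)          # (values, indices)

    def topk_softmax_cpu_stub(gating_output: torch.Tensor, topk: int):
        # stand-in for torch.ops.sgl_kernel.topk_softmax_cpu
        return topk_softmax_torch(gating_output, topk)

    use_cpu_amx = False   # i.e. is_cpu() and cpu_has_amx_support()
    topk_softmax_fn = topk_softmax_cpu_stub if use_cpu_amx else topk_softmax_torch

    weights, ids = topk_softmax_fn(torch.randn(4, 16), topk=2)
    print(weights.shape, ids.shape)   # torch.Size([4, 2]) torch.Size([4, 2])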
@@ -3,10 +3,13 @@
 
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
 
 
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
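The new CrossEncodingPooler walks a packed batch: all prompts' hidden states are concatenated along dim 0 and forward_batch.extend_seq_lens gives each prompt's length, so scoring is a slice-per-prompt loop. A minimal sketch of that slicing with toy shapes:

    import torch

    hidden_states = torch.randn(7, 4)   # two prompts (lengths 3 and 4) packed together
    prompt_lens = [3, 4]                # what extend_seq_lens provides per prompt

    offset, chunks = 0, []
    for prompt_len in prompt_lens:
        chunks.append(hidden_states[offset : offset + prompt_len])
        offset += prompt_len

    print([tuple(c.shape) for c in chunks])  # [(3, 4), (4, 4)]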
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quan
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -0,0 +1 @@
+from .entrypoint import *
@@ -5,34 +5,23 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
-_ENABLE_JIT_DEEPGEMM = False
 
-try:
-    import deep_gemm
+if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
-except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
-
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
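The inline try/except probe that used to set _ENABLE_JIT_DEEPGEMM is gone; compile_utils.py now imports ENABLE_JIT_DEEPGEMM and DEEPGEMM_BLACKWELL from the new deep_gemm_wrapper/configurer.py (+32 lines, not shown in this diff). A hedged sketch of what such a configurer module could look like, reusing only the checks visible in the removed code; the SM-100/Blackwell detection is an assumption:

    # Hypothetical configurer sketch; the real configurer.py may differ.
    import importlib.util

    from sglang.srt.utils import get_bool_env_var, get_device_sm

    _has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
    _sm = get_device_sm()

    ENABLE_JIT_DEEPGEMM = (
        _has_deep_gemm
        and _sm in (90, 100)  # assumption: Hopper as before, plus Blackwell
        and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
    )
    DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _sm == 100  # assumption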
@@ -52,8 +41,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if _ENABLE_JIT_DEEPGEMM:
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,11 +105,12 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "And it may takes a long time(Typically 10-20 mins) "
+            "It may takes a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -127,6 +119,7 @@ def _compile_warning_1():
     )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +231,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +264,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
 
@@ -304,56 +297,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -368,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         ret = origin_func(self, *args, **kwargs)
         if ret is None:
             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            _compile_warning_2()
+            if not DEEPGEMM_BLACKWELL:
+                _compile_warning_2()
             logger.warning(
                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
             )
@@ -380,13 +324,12 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
 
 @contextmanager
-def configure_deep_gemm_num_sms(num_sms):
-    if num_sms is None:
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if not DEEPGEMM_BLACKWELL:
+        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)
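configure_deep_gemm_num_sms is replaced by deep_gemm_execution_hook, a single context manager that triggers shape-family precompilation (on non-Blackwell builds) and logs any JIT build around the actual GEMM launch. A self-contained sketch of that compile-on-first-use hook pattern, with illustrative names; the callers presumably live in the new deep_gemm_wrapper/entrypoint.py, which this diff does not show:

    from contextlib import contextmanager

    _precompiled: set = set()

    @contextmanager
    def execution_hook(n: int, k: int, num_groups: int):
        key = (n, k, num_groups)
        if key not in _precompiled:   # compile-on-first-use, like _maybe_compile_deep_gemm_one_type_all
            print(f"precompiling kernel family n={n}, k={k}, groups={num_groups}")
            _precompiled.add(key)
        yield                         # the actual GEMM launch runs inside the hook

    with execution_hook(4096, 7168, 8):
        pass  # e.g. the fp8 grouped GEMM call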