sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the sglang package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -24,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -35,9 +38,13 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -518,10 +525,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
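The local `ceil_div` removed above is replaced by an import from the new `sglang.math_utils` module (listed as `sglang/math_utils.py +8` in the file summary). A minimal sketch of the relocated helper, assuming it keeps the semantics of the removed definition:

# Minimal sketch of the relocated helper (assumed to match the definition
# removed from fused_moe.py; the actual sglang/math_utils.py may contain more).
def ceil_div(a: int, b: int) -> int:
    """Integer division that rounds up, e.g. ceil_div(10, 4) == 3."""
    return (a + b - 1) // b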
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -32,6 +32,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -204,7 +205,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_weights, dtype=torch.float32
                 )  # topk_weights must be FP32 (float32)
 
-            return
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -241,7 +242,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
-    forward_native =
+    forward_native = forward_cpu
 
 
 class FusedMoE(torch.nn.Module):
@@ -310,6 +315,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -320,9 +327,40 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
@@ -340,7 +378,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +385,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
             layer=self,
-            num_experts=
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +485,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-
-
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
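The rewritten w13 loading above replaces the per-shard branch with a single offset computation: gate_proj ("w1") and up_proj ("w3") land in the two halves of the fused parameter, and `load_up_proj_weight_first` swaps the halves for kernels that expect up_proj first. A standalone sketch with hypothetical sizes:

import torch

# Standalone sketch of the offset selection above (hypothetical sizes).
# w13 fuses gate_proj ("w1") and up_proj ("w3") along dim 0, one half of
# `shard_size` rows each; `switch_w13` models load_up_proj_weight_first.
shard_size = 4
w13 = torch.zeros(2 * shard_size, 8)

def dest_offset(shard_id: str, switch_w13: bool) -> int:
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        return shard_size
    return 0

gate_dst = w13.narrow(0, dest_offset("w1", switch_w13=False), shard_size)  # rows 0..3
up_dst = w13.narrow(0, dest_offset("w3", switch_w13=False), shard_size)    # rows 4..7
print(gate_dst.shape, up_dst.shape)  # torch.Size([4, 8]) torch.Size([4, 8])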
@@ -505,6 +547,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -513,6 +560,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +591,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +598,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim =
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +739,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
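The `**( ... if ... else {} )` splice above forwards the parallel-layout kwargs only when the ModelOpt NVFP4 method is in use. A standalone illustration of the idiom; `combine` and `needs_parallel_info` are made-up names for demonstration:

# Standalone illustration of the conditional keyword-splice idiom used above.
def combine(x: float, *, scale: float = 1.0, tp_rank: int = 0, ep_rank: int = 0) -> float:
    return x * scale + tp_rank + ep_rank

needs_parallel_info = True  # stands in for the ModelOptNvFp4FusedMoEMethod check
result = combine(
    2.0,
    scale=3.0,
    **(dict(tp_rank=1, ep_rank=2) if needs_parallel_info else {}),
)
print(result)  # 9.0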
sglang/srt/layers/moe/topk.py
CHANGED
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
 
-def
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -249,7 +297,16 @@ def _mask_topk_ids_padded_region(
     topk_ids[indices >= num_token_non_padded, :] = -1
 
 
-
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _biased_grouped_topk_postprocess(
+    topk_ids, expert_location_dispatch_info, num_token_non_padded
+):
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_ids
+
+
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -282,14 +339,13 @@ def biased_grouped_topk(
             num_fused_shared_experts,
             routed_scaling_factor,
         )
-        # TODO merge into kernel
-
-
-
-
-
-
-        )(topk_ids, num_token_non_padded)
+        # TODO merge into kernel
+        if (expert_location_dispatch_info is not None) or (
+            num_token_non_padded is not None
+        ):
+            topk_ids = _biased_grouped_topk_postprocess(
+                topk_ids, expert_location_dispatch_info, num_token_non_padded
+            )
         return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
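The refactor above folds the logical-to-physical expert remap and the padded-token masking into one `torch.compile`d postprocess that is skipped when neither feature is active. A standalone sketch of just the masking step (the remap depends on sglang's expert-location tables and is omitted here):

import torch

# Standalone sketch of the padded-region masking done by the postprocess
# above: rows at or beyond num_token_non_padded belong to padding tokens,
# so their expert ids are set to -1 and ignored downstream.
topk_ids = torch.tensor([[3, 7], [1, 4], [2, 5], [0, 6]])
num_token_non_padded = torch.tensor(3)

indices = torch.arange(topk_ids.shape[0])
topk_ids[indices >= num_token_non_padded, :] = -1
print(topk_ids.tolist())  # [[3, 7], [1, 4], [2, 5], [-1, -1]]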
@@ -314,6 +370,45 @@ def biased_grouped_topk(
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
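The module-level dispatch above routes `grouped_topk`, `biased_grouped_topk`, and `fused_topk_native` to the `sgl_kernel` CPU ops when AMX is available, and to the GPU/torch-native versions otherwise. For reference, a minimal sketch of the usual softmax-then-top-k routing that the torch-native fallback computes (illustrative only; the in-tree implementation also handles dtype and output-buffer details):

import torch

# Minimal sketch of softmax-top-k MoE routing, assuming the standard
# formulation behind the torch-native fallback.
def topk_route(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

weights, ids = topk_route(torch.randn(4, 16), topk=2, renormalize=True)
print(ids.shape)  # torch.Size([4, 2]); weights renormalized to sum to 1 per token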
sglang/srt/layers/pooler.py
CHANGED
@@ -3,10 +3,13 @@
 
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
 
 
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 
sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
ADDED
@@ -0,0 +1 @@
+from .entrypoint import *
sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py}
RENAMED
@@ -5,34 +5,23 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var,
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
-_ENABLE_JIT_DEEPGEMM = False
 
-
-    import deep_gemm
+if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
-except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
-
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
@@ -52,8 +41,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,11 +105,12 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "
+            "It may takes a long time (typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
            " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -127,6 +119,7 @@ def _compile_warning_1():
         )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +231,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +264,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
 
@@ -304,56 +297,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -368,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         ret = origin_func(self, *args, **kwargs)
         if ret is None:
             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-
+            if not DEEPGEMM_BLACKWELL:
+                _compile_warning_2()
             logger.warning(
                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
             )
@@ -380,13 +324,12 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
 
 @contextmanager
-def
-
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if not DEEPGEMM_BLACKWELL:
+        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)