sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import importlib
 import sys
 from types import MappingProxyType
-from typing import
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter
@@ -11,21 +13,20 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearMethodBase,
-    RowParallelLinear,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
     cpu_has_amx_support,
@@ -36,6 +37,9 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
@@ -178,17 +182,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
-
-
-
-
-
-
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )

+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -229,14 +234,14 @@ class W8A8Int8Config(QuantizationConfig):
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)

     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

@@ -262,12 +267,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-
-
-
-
-
-
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None

     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
@@ -374,7 +383,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         )


-class W8A8Int8MoEMethod:
+class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
     Supports loading INT8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -385,25 +394,7 @@ class W8A8Int8MoEMethod:
         quant_config: The quantization config.
     """

-    def
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Int8Config):
         self.quant_config = quant_config

     def create_weights(
@@ -481,15 +472,8 @@
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -497,24 +481,14 @@
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            topk_weights, topk_ids, _ = topk_output
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
@@ -536,8 +510,7 @@
             x,
             layer.w13_weight,
             layer.w2_weight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
@@ -761,6 +734,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
@@ -885,13 +860,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)


-class NPU_W8A8MoEMethod:
+class NPU_W8A8MoEMethod(FusedMoEMethodBase):
     """MoE method for NPU quantization.

     This class search for specific quantization
@@ -910,7 +887,7 @@ class NPU_W8A8MoEMethod:
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size:
+        intermediate_size: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -987,52 +964,11 @@
         self,
         layer,
         x,
-
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        num_fused_shared_experts,
-        custom_routing_function,
-        correction_bias,
-        activation,
-        apply_router_weight_on_input,
-        routed_scaling_factor,
+        topk_output: TopKOutput,
         **kwargs,
     ) -> torch.Tensor:
-
-
-        global_num_experts = router_logits.shape[-1]
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,
-                bias=correction_bias,
-                k_group=topk_group,
-                group_count=num_expert_group,
-                group_select_mode=1,
-                renorm=0,
-                norm_type=1,
-                routed_scaling_factor=1,
-                eps=float(1e-20),
-            )
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                torch_native=True,
-                routed_scaling_factor=routed_scaling_factor,
-            )
+
+        topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
         return npu_fused_experts(
@@ -1043,5 +979,5 @@
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            top_k=
+            top_k=topk_ids.shape[1],
         )
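The hunks above all follow one refactor: expert selection no longer happens inside the quantized MoE methods. Instead of `router_logits`, `top_k`, `renormalize`, `use_grouped_topk`, and the other routing arguments, `apply()` now receives a single precomputed `topk_output` and unpacks it. The following is a minimal sketch of that calling convention, not the actual sglang API: the field names and the toy router are assumptions for illustration, and the diff only guarantees that the object unpacks as `topk_weights, topk_ids, _ = topk_output`.

# Sketch only -- assumed names, not the real sglang.srt.layers.moe.topk.TopKOutput.
from typing import NamedTuple, Optional

import torch


class TopKOutput(NamedTuple):
    # Field names are guesses based on how the diff unpacks the object.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: Optional[torch.Tensor] = None


def select_topk(x: torch.Tensor, num_experts: int, top_k: int) -> TopKOutput:
    # Toy router: softmax over random logits, keep the top-k experts per token.
    logits = torch.randn(x.shape[0], num_experts)
    weights, ids = torch.topk(torch.softmax(logits, dim=-1), k=top_k, dim=-1)
    return TopKOutput(weights, ids, logits)


def moe_apply(x: torch.Tensor, topk_output: TopKOutput) -> torch.Tensor:
    # New-style quant method: it no longer calls select_experts itself;
    # it only unpacks the routing result computed upstream.
    topk_weights, topk_ids, _ = topk_output
    # ... dispatch x to the experts in topk_ids, scale by topk_weights ...
    return x


x = torch.randn(4, 16)
out = moe_apply(x, select_topk(x, num_experts=8, top_k=2))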
sglang/srt/layers/radix_attention.py
CHANGED
@@ -12,14 +12,16 @@
 # limitations under the License.
 # ==============================================================================
 """Radix attention."""
+from __future__ import annotations

 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from torch import nn

-
-from sglang.srt.
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class AttentionType(Enum):
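This hunk applies the same import hygiene as w8a8_int8.py: with `from __future__ import annotations`, annotations are lazily evaluated strings, so typing-only imports can sit under `TYPE_CHECKING` and never execute at runtime, which avoids import cycles. A generic sketch of the idiom, using one of the module paths from the hunk; the function below is illustrative, not taken from the file.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never imported at runtime.
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def run(forward_batch: ForwardBatch) -> None:
    # At runtime the annotation is just the string "ForwardBatch",
    # so no circular or heavyweight import is triggered.
    ...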
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple

 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter

 from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)


-class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-    """Unquantized method for embeddings."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for embedding layer."""
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
-
-    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
-
-
 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -569,8 +531,6 @@ class ParallelLMHead(VocabParallelEmbedding):
         if _is_cpu and _is_cpu_amx_available:
             if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])
-            else:
-                logger.warning("The weight of LmHead is not packed")

         if bias:
             self.bias = Parameter(
sglang/srt/lora/lora.py
CHANGED
@@ -186,10 +186,6 @@ class LoRAAdapter(nn.Module):
             up_name = weight_name.replace("gate_proj", "up_proj")
             gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
             if up_name not in weights:
-                logger.warning(
-                    f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
-                    f"Initializing up projection to zero."
-                )
                 weights[up_name] = torch.zeros_like(weights[weight_name])
             # FIXME: Add gate-only support for flashinfer in future implementations
             assert self.lora_backend.name == "triton", (