sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import importlib
 import sys
 from types import MappingProxyType
-from typing import
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter
@@ -11,21 +13,20 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearMethodBase,
-    RowParallelLinear,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
     cpu_has_amx_support,
@@ -36,6 +37,9 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
@@ -178,17 +182,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
-
-
-
-
-
-
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )

+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -229,14 +234,14 @@ class W8A8Int8Config(QuantizationConfig):
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)

     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

@@ -262,12 +267,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-
-
-
-
-
-
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None

     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
@@ -374,7 +383,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         )


-class W8A8Int8MoEMethod:
+class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
     Supports loading INT8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -385,25 +394,7 @@ class W8A8Int8MoEMethod:
         quant_config: The quantization config.
     """

-    def
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Int8Config):
         self.quant_config = quant_config

     def create_weights(
@@ -481,15 +472,8 @@ class W8A8Int8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -497,24 +481,14 @@ class W8A8Int8MoEMethod:
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            topk_weights, topk_ids, _ = topk_output
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
@@ -536,8 +510,7 @@ class W8A8Int8MoEMethod:
             x,
             layer.w13_weight,
             layer.w2_weight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
@@ -761,6 +734,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
@@ -885,13 +860,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)


-class NPU_W8A8MoEMethod:
+class NPU_W8A8MoEMethod(FusedMoEMethodBase):
     """MoE method for NPU quantization.

     This class search for specific quantization
@@ -910,7 +887,7 @@ class NPU_W8A8MoEMethod:
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size:
+        intermediate_size: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -987,52 +964,11 @@ class NPU_W8A8MoEMethod:
         self,
         layer,
         x,
-
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        num_fused_shared_experts,
-        custom_routing_function,
-        correction_bias,
-        activation,
-        apply_router_weight_on_input,
-        routed_scaling_factor,
+        topk_output: TopKOutput,
         **kwargs,
     ) -> torch.Tensor:
-
-
-        global_num_experts = router_logits.shape[-1]
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,
-                bias=correction_bias,
-                k_group=topk_group,
-                group_count=num_expert_group,
-                group_select_mode=1,
-                renorm=0,
-                norm_type=1,
-                routed_scaling_factor=1,
-                eps=float(1e-20),
-            )
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                num_fused_shared_experts=num_fused_shared_experts,
-                custom_routing_function=custom_routing_function,
-                correction_bias=correction_bias,
-                torch_native=True,
-                routed_scaling_factor=routed_scaling_factor,
-            )
+
+        topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
         return npu_fused_experts(
@@ -1043,5 +979,5 @@ class NPU_W8A8MoEMethod:
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            top_k=
+            top_k=topk_ids.shape[1],
         )
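Note: the w8a8_int8.py hunks above replace the per-call routing arguments of the MoE `apply` methods (`top_k`, `renormalize`, `use_grouped_topk`, `correction_bias`, ...) with a single precomputed `topk_output`. A minimal standalone sketch of what that unpacking convention implies, using a stand-in NamedTuple (the real type is `sglang.srt.layers.moe.topk.TopKOutput`; only the first two field names are confirmed by the diff, the third is an assumption):

import torch
from typing import NamedTuple

class TopKOutputSketch(NamedTuple):
    # Stand-in for sglang.srt.layers.moe.topk.TopKOutput; the third field name is an assumption.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

def apply_moe_sketch(x: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
    # Mirrors the new pattern in the hunks above: the quant method unpacks a
    # precomputed top-k result instead of calling select_experts() itself.
    topk_weights, topk_ids, _ = topk_output
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = topk_weights.to(x.dtype)
    # ... dispatch to the fused experts kernel with topk_weights / topk_ids ...
    return x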
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple

 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter

 from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)


-class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-    """Unquantized method for embeddings."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for embedding layer."""
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
-
-    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
-
-
 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@@ -569,8 +531,6 @@ class ParallelLMHead(VocabParallelEmbedding):
         if _is_cpu and _is_cpu_amx_available:
             if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])
-            else:
-                logger.warning("The weight of LmHead is not packed")

         if bias:
             self.bias = Parameter(
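Note: `UnquantizedEmbeddingMethod` (and, per the hunks further above, `UnquantizedLinearMethod`) now live in the new `sglang/srt/layers/quantization/unquant.py` module. A sketch of the import change implied by the removed and added lines; the old paths come from the removed lines, the new path from the added imports:

# 0.4.9.post2 locations (removed in this release):
# from sglang.srt.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod
# from sglang.srt.layers.linear import UnquantizedLinearMethod

# 0.4.9.post3 location (added in this release):
from sglang.srt.layers.quantization.unquant import (
    UnquantizedEmbeddingMethod,
    UnquantizedLinearMethod,
)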
sglang/srt/lora/lora.py
CHANGED
@@ -186,10 +186,6 @@ class LoRAAdapter(nn.Module):
             up_name = weight_name.replace("gate_proj", "up_proj")
             gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
             if up_name not in weights:
-                logger.warning(
-                    f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
-                    f"Initializing up projection to zero."
-                )
                 weights[up_name] = torch.zeros_like(weights[weight_name])
             # FIXME: Add gate-only support for flashinfer in future implementations
             assert self.lora_backend.name == "triton", (
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, Set, Tuple
+from typing import Dict, Iterable, Optional, Set, Tuple

 import torch

@@ -53,6 +53,8 @@ class LoRAManager:
         lora_backend: str = "triton",
         tp_size: int = 1,
         tp_rank: int = 0,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,6 +64,10 @@ class LoRAManager:
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: Optional[int] = max_lora_rank
+        self.target_modules: Optional[Set[str]] = (
+            set(target_modules) if target_modules else None
+        )

         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -153,7 +159,9 @@ class LoRAManager:
             error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

         try:
-
+            new_adapter = LoRAConfig(lora_path)
+            self.validate_new_adapter(lora_name, new_adapter)
+            self.configs[lora_name] = new_adapter
         except Exception as e:
             success = False
             error_message = (
@@ -168,6 +176,21 @@ class LoRAManager:
             error_message=error_message,
         )

+    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+        """
+        Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
+        """
+
+        incompatible = self.memory_pool and not self.memory_pool.can_support(
+            lora_config
+        )
+        if incompatible:
+            raise ValueError(
+                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
+                "included in `--enable_lora_modules`."
+            )
+
     def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
@@ -214,7 +237,7 @@ class LoRAManager:
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
             if lora_path is not None:
                 lora = self.loras[lora_path]
-                lora_ranks[weight_indices[i]] = lora.config.
+                lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling

         # Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -319,7 +342,7 @@ class LoRAManager:
                 )
             else:
                 weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                    module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                 )
                 module.set_lora_info(
                     self.memory_pool.get_tensor(
@@ -351,58 +374,67 @@ class LoRAManager:
             i: {} for i in range(self.base_hf_config.num_hidden_layers)
         }

-        #
-
-
-            self.max_loras_per_batch,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-        )
+        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
+        # It is initialized lazily when the first LoRA adapter is loaded.
+        self.memory_pool: Optional[LoRAMemoryPool] = None

     def update_state_from_configs(self):
         """
         Update the internal state of the LoRAManager based on the current `self.configs`. This method
         should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-
-        This includes:
-        - Initializing LoRA adapters if they are not already loaded.
-        - Collect all LoRA weight names based on the current loaded adapters.
-        - Lazily monkey-patching the base model to use LoRA layers where applicable.
-        - Preparing the GPU buffer pool for active LoRA weights.
         """

-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        hf_target_module_names: Set[str] = set()
-        for config in self.configs.values():
-            hf_target_module_names.update(config.target_modules)
-        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
         # Loads / unloads LoRA adapters based on the latest configs.
         self.update_lora_adapters()
+        # Apply the latest LoRA configurations to the internal state for inferencing.
+        self.apply_lora_configs()
+
+    def apply_lora_configs(self):
+        """
+        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
+
+        Notes:
+            - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
+              we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
+              LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
+              early CY25H2.
+        """
+
+        if self.memory_pool is None:
+            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
+            if self.target_modules is None:
+                self.target_modules = set()
+                for config in self.configs.values():
+                    self.target_modules.update(config.target_modules)
+
+            if self.max_lora_rank is None:
+                self.max_lora_rank = max(
+                    [x.hf_config["r"] for x in self.configs.values()],
+                    default=0,
+                )
+
+            self.update_lora_weight_names()
+            self.update_lora_modules()
+            self.update_memory_buffers()
+        else:
+            # No-op if the memory pool can support the current LoRA configurations.
+            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
+            # module is changed once FlashInfer backend is deprecated.
+            assert self.memory_pool.can_support(self.configs.values()), (
+                "LoRA memory pool cannot support the current LoRA configuration. "
+                "This should never happen as we should have validated adapter compatibility. "
+                "Please create a Github issue to report.",
+            )

-
-        #
-        # Please note that the following update operations are "monotonic" by design, meaning that we update
-        # multiple places to support the new weight names when the first adapter targeting such weight names
-        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-        # list of LoRA weight names is expected to be extremely finite and stable.
-        self.update_lora_weight_names(hf_target_module_names)
-        self.update_lora_modules(hf_target_module_names)
-        self.update_memory_buffers(max_lora_dim)
-
-    def update_lora_weight_names(self, hf_target_names: Set[str]):
+    def update_lora_weight_names(self):
         """
         Add new LoRA weight names if needed based on the current `self.configs`.
         """

         # Target lora weight names for lora_a and lora_b modules respectively.
-
-
-
-        self.lora_weight_names[1].update(lora_B)
+        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+        self.lora_weight_names[0].update(lora_A)
+        self.lora_weight_names[1].update(lora_B)

     def update_lora_adapters(self):
         """
@@ -434,21 +466,23 @@ class LoRAManager:
         # Additional checks for flashinfer backend
         # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
         if self.lora_backend == "flashinfer":
-            lora_dims = set(x.
+            lora_dims = set(x.r for x in self.configs.values())
             scalings = set(x.scaling for x in self.loras.values())
             assert (
                 len(lora_dims) == 1 and len(scalings) == 1
             ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-    def update_memory_buffers(self
-        """
-
-
-
-
-
-
-        self.
+    def update_memory_buffers(self):
+        """(Re)initialize the LoRA memory pool based on the current configurations."""
+        self.memory_pool = LoRAMemoryPool(
+            base_hf_config=self.base_hf_config,
+            max_loras_per_batch=self.max_loras_per_batch,
+            dtype=self.dtype,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            max_lora_rank=self.max_lora_rank,
+            lora_weight_names=self.lora_weight_names,
+            base_model=self.base_model,
         )

     def set_lora_module(self, module_name, module):
@@ -456,11 +490,11 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def update_lora_modules(self
+    def update_lora_modules(self):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-
+            self.target_modules, self.base_model
         )

         for module_name, module in self.base_model.named_modules():
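Note: the lora_manager.py hunks fix the LoRA memory pool's shape (maximum rank and target modules) at initialization, either from the new `max_lora_rank`/`target_modules` constructor arguments or inferred from the first loaded adapters, and later adapters are validated against it. A minimal standalone sketch of that validation rule (the class names below are stand-ins, not the sglang implementation; only the `can_support` check and the `--max_lora_rank` constraint appear in the diff):

from dataclasses import dataclass
from typing import Set

@dataclass
class AdapterConfigSketch:
    r: int                    # LoRA rank, read from the adapter's hf_config["r"]
    target_modules: Set[str]  # e.g. {"q_proj", "k_proj", "v_proj", "o_proj"}

@dataclass
class MemoryPoolSketch:
    max_lora_rank: int
    target_modules: Set[str]

    def can_support(self, cfg: AdapterConfigSketch) -> bool:
        # The rank must fit the preallocated buffers and the adapter may only touch
        # modules the pool was built for (mirrors the error message in the diff).
        return cfg.r <= self.max_lora_rank and cfg.target_modules <= self.target_modules

pool = MemoryPoolSketch(max_lora_rank=16, target_modules={"q_proj", "v_proj"})
adapter = AdapterConfigSketch(r=32, target_modules={"q_proj", "v_proj"})
if not pool.can_support(adapter):
    print("rejected: rank 32 exceeds the pool's configured max_lora_rank of 16")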