sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +15 -6
- sglang/srt/layers/attention/flashinfer_backend.py +17 -3
- sglang/srt/layers/linear.py +36 -98
- sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +24 -16
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +106 -52
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -2
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/scheduler.py +48 -9
- sglang/srt/managers/tokenizer_manager.py +109 -49
- sglang/srt/mem_cache/memory_pool.py +107 -52
- sglang/srt/metrics/collector.py +10 -5
- sglang/srt/model_executor/model_runner.py +43 -6
- sglang/srt/models/llama.py +37 -2
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +14 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +114 -61
- sglang/srt/server_args.py +27 -18
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py CHANGED

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -40,12 +39,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )

 ACTIVATION_SCHEMES = ["static", "dynamic"]

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)


@@ -161,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False

         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -567,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -616,18 +618,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

-
-
-
-
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +672,7 @@ class Fp8MoEMethod:
             )

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +722,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )

-
-
-
-
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return

     def apply(
@@ -752,27 +778,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )


 class Fp8KVCacheMethod(BaseKVCacheMethod):
sglang/srt/layers/quantization/fp8_utils.py CHANGED

@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple

 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter

+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py ADDED

@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
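The new kernel performs dynamic per-token symmetric quantization: each row is scaled by its absolute maximum divided by 127 and rounded to int8. As a rough reference only (not part of the package), the same math in plain PyTorch, which can be used to sanity-check the Triton kernel on small inputs, might look like:

```python
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # Per-row (per-token) symmetric int8 quantization, mirroring the Triton kernel:
    # scale = absmax / 127, values rounded to int8.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10).to(torch.float32)
    scale = absmax / 127.0
    x_q = torch.round(x.to(torch.float32) / scale).to(torch.int8)
    return x_q, scale


x = torch.randn(4, 128, dtype=torch.float16)
x_q, scale = per_token_quant_int8_ref(x)
# Dequantization recovers x up to rounding error (at most half a quantization step).
assert torch.allclose(x_q.to(torch.float32) * scale, x.to(torch.float32), atol=scale.max().item())
```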
sglang/srt/layers/quantization/modelopt_quant.py CHANGED

@@ -11,9 +11,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

 from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
sglang/srt/layers/quantization/w8a8_int8.py ADDED

@@ -0,0 +1,117 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+class W8A8Int8Config(QuantizationConfig):
+    """Config class for W8A8 Int8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_int8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Int8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Int8Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
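W8A8Int8LinearMethod.apply quantizes activations per token on the fly and feeds them, together with the pre-quantized int8 weight and its per-output-channel scale, to sgl_kernel's int8_scaled_mm. As a reference only (not the code path the package uses), the underlying arithmetic can be written unfused in plain PyTorch, assuming the weight layout created above: an (out_features, in_features) int8 weight and an (out_features, 1) fp32 scale:

```python
import torch


def w8a8_int8_linear_ref(x, weight_int8, weight_scale, bias=None):
    # Dynamic per-token activation quantization (same math as per_token_quant_int8).
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10).to(torch.float32)
    x_scale = absmax / 127.0
    x_q = torch.round(x.to(torch.float32) / x_scale).to(torch.int8)

    # int8 x int8 GEMM with int32 accumulation (emulated), then rescale to float.
    acc = x_q.to(torch.int32) @ weight_int8.to(torch.int32).t()
    out = acc.to(torch.float32) * x_scale * weight_scale.t()  # (M,1) * (1,N) broadcast
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(x.dtype)


# Shape check: 4 tokens, hidden size 64, 32 output channels.
x = torch.randn(4, 64, dtype=torch.float16)
w = torch.randint(-127, 128, (32, 64), dtype=torch.int8)
w_scale = torch.rand(32, 1, dtype=torch.float32) * 0.01
assert w8a8_int8_linear_ref(x, w, w_scale).shape == (4, 32)
```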
sglang/srt/layers/vocab_parallel_embedding.py CHANGED

@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )

         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)

@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
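The new use_presharded_weights flag indicates that the checkpoint already stores only this tensor-parallel rank's slice of the vocabulary, so the loader copies the shard as-is instead of narrowing the full table. A minimal sketch of that behavior (hypothetical helper, assuming the vocab dimension is dim 0):

```python
import torch


def load_vocab_shard(param, loaded_weight, start_idx, shard_size, use_presharded_weights):
    # With a presharded checkpoint the tensor already holds only this rank's rows,
    # so the narrow() step is skipped; otherwise slice the full table first.
    if not use_presharded_weights:
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
    param[: loaded_weight.shape[0]].copy_(loaded_weight)
    param[loaded_weight.shape[0] :].fill_(0)


# Example: vocab of 8 rows split across tp_size=2; rank 0 owns rows 0..3,
# padded to a 5-row shard. Both loading paths fill the shard identically.
full = torch.arange(8.0).unsqueeze(1)  # full (8, 1) embedding table
pre = full[:4].clone()                 # presharded checkpoint slice for rank 0
shard_a, shard_b = torch.empty(5, 1), torch.empty(5, 1)
load_vocab_shard(shard_a, full, start_idx=0, shard_size=4, use_presharded_weights=False)
load_vocab_shard(shard_b, pre, start_idx=0, shard_size=4, use_presharded_weights=True)
assert torch.equal(shard_a, shard_b)
```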
sglang/srt/managers/configure_logging.py ADDED

@@ -0,0 +1,43 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
sglang/srt/managers/detokenizer_manager.py CHANGED

@@ -181,8 +181,6 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
sglang/srt/managers/io_struct.py CHANGED

@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional,
-
-import torch
+from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -323,9 +321,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -356,14 +352,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]

-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -468,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list


+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -479,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2


+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int