sglang 0.3.1.post2__py3-none-any.whl → 0.3.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +8 -1
- sglang/srt/layers/activation.py +3 -2
- sglang/srt/layers/attention_backend.py +3 -1
- sglang/srt/layers/linear.py +1133 -0
- sglang/srt/layers/quantization/__init__.py +76 -0
- sglang/srt/layers/quantization/base_config.py +122 -0
- sglang/srt/models/baichuan.py +1 -1
- sglang/srt/models/chatglm.py +6 -6
- sglang/srt/models/commandr.py +7 -7
- sglang/srt/models/dbrx.py +7 -7
- sglang/srt/models/deepseek.py +7 -7
- sglang/srt/models/deepseek_v2.py +7 -7
- sglang/srt/models/exaone.py +6 -6
- sglang/srt/models/gemma.py +6 -6
- sglang/srt/models/gemma2.py +6 -6
- sglang/srt/models/gpt_bigcode.py +6 -6
- sglang/srt/models/grok.py +6 -6
- sglang/srt/models/internlm2.py +6 -6
- sglang/srt/models/llama.py +6 -6
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +6 -6
- sglang/srt/models/minicpm3.py +1 -1
- sglang/srt/models/mixtral.py +6 -6
- sglang/srt/models/mixtral_quant.py +6 -6
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen.py +6 -6
- sglang/srt/models/qwen2.py +6 -6
- sglang/srt/models/qwen2_moe.py +7 -7
- sglang/srt/models/stablelm.py +6 -6
- sglang/srt/models/xverse.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/models/yivl.py +1 -1
- sglang/srt/utils.py +21 -1
- sglang/test/test_utils.py +4 -2
- sglang/version.py +1 -1
- {sglang-0.3.1.post2.dist-info → sglang-0.3.1.post3.dist-info}/METADATA +3 -2
- {sglang-0.3.1.post2.dist-info → sglang-0.3.1.post3.dist-info}/RECORD +42 -39
- {sglang-0.3.1.post2.dist-info → sglang-0.3.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post2.dist-info → sglang-0.3.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post2.dist-info → sglang-0.3.1.post3.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -64,8 +64,13 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
64
64
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
65
65
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
66
66
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
67
|
+
from sglang.srt.server import _set_envs_and_config
|
67
68
|
from sglang.srt.server_args import ServerArgs
|
68
|
-
from sglang.srt.utils import
|
69
|
+
from sglang.srt.utils import (
|
70
|
+
configure_logger,
|
71
|
+
kill_child_process,
|
72
|
+
suppress_other_loggers,
|
73
|
+
)
|
69
74
|
|
70
75
|
|
71
76
|
@dataclasses.dataclass
|
@@ -341,6 +346,8 @@ def latency_test(
|
|
341
346
|
bench_args,
|
342
347
|
tp_rank,
|
343
348
|
):
|
349
|
+
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
350
|
+
_set_envs_and_config(server_args)
|
344
351
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
345
352
|
|
346
353
|
# Load the model
|
sglang/srt/layers/activation.py
CHANGED
@@ -31,8 +31,9 @@ from vllm.distributed import (
|
|
31
31
|
get_tensor_model_parallel_world_size,
|
32
32
|
)
|
33
33
|
from vllm.model_executor.custom_op import CustomOp
|
34
|
-
|
35
|
-
from
|
34
|
+
|
35
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
36
|
+
from sglang.srt.utils import set_weight_attrs
|
36
37
|
|
37
38
|
logger = logging.getLogger(__name__)
|
38
39
|
|
@@ -346,7 +346,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
346
346
|
|
347
347
|
self.decode_attention_fwd = decode_attention_fwd
|
348
348
|
self.extend_attention_fwd = extend_attention_fwd
|
349
|
-
self.num_head =
|
349
|
+
self.num_head = (
|
350
|
+
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
351
|
+
)
|
350
352
|
|
351
353
|
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
|
352
354
|
self.reduce_dtype = torch.float32
|