sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import psutil
@@ -49,7 +49,6 @@ class ReqToTokenPool:
         size: int,
         max_context_len: int,
         device: str,
-        use_records: bool,
         enable_memory_saver: bool,
     ):
         memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -64,17 +63,9 @@ class ReqToTokenPool:
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
-        self.write_records = []
-        self.use_records = use_records
-
-        if self.use_records:
-            self.write = self.write_with_records
-        else:
-            self.write = self.write_without_records

     def write(self, indices, values):
-
-        raise NotImplementedError()
+        self.req_to_token[indices] = values

     def available_size(self):
         return len(self.free_slots)
@@ -96,23 +87,6 @@ class ReqToTokenPool:

     def clear(self):
         self.free_slots = list(range(self.size))
-        self.write_records = []
-
-    def write_without_records(self, indices, values):
-        self.req_to_token[indices] = values
-
-    def write_with_records(self, indices, values):
-        self.req_to_token[indices] = values
-        self.write_records.append((indices, values))
-
-    def get_write_records(self):
-        ret = self.write_records
-        self.write_records = []
-        return ret
-
-    def apply_write_records(self, write_records: List[Tuple]):
-        for indices, values in write_records:
-            self.req_to_token[indices] = values


 class BaseTokenToKVPool:
@@ -296,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float =
-        v_scale: float =
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-
-
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
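Note on the set_kv_buffer change above: k_scale and v_scale are now Optional and are only applied when a value is passed, before the incoming tensors are cast to the pool dtype. A minimal standalone sketch of that branch (the function and variable names here are illustrative, not part of sglang):

import torch

def scale_and_cast_kv(cache_k, cache_v, target_dtype, k_scale=None, v_scale=None):
    # Mirror of the new logic: divide by the per-tensor scale only when one
    # is provided, then cast to the KV-cache dtype (e.g. an fp8 format).
    if cache_k.dtype != target_dtype:
        if k_scale is not None:
            cache_k = cache_k / k_scale
        if v_scale is not None:
            cache_v = cache_v / v_scale
        cache_k = cache_k.to(target_dtype)
        cache_v = cache_v.to(target_dtype)
    return cache_k, cache_v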
sglang/srt/metrics/collector.py
CHANGED
@@ -25,6 +25,7 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    spec_accept_length: float = 0.0


 class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:

         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
-            documentation="The number of running requests",
+            documentation="The number of running requests.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )

         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
-            documentation="The number of used tokens",
+            documentation="The number of used tokens.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )

         self.token_usage = Gauge(
             name="sglang:token_usage",
-            documentation="The token usage",
+            documentation="The token usage.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )

         self.gen_throughput = Gauge(
             name="sglang:gen_throughput",
-            documentation="The
+            documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )

         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
-            documentation="The number of requests in the waiting queue",
+            documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )

         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
-            documentation="The cache hit rate",
+            documentation="The prefix cache hit rate.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.spec_accept_length = Gauge(
+            name="sglang:spec_accept_length",
+            documentation="The average acceptance length of speculative decoding.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)


 class TokenizerMetricsCollector:
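The collector.py diff adds a spec_accept_length field to SchedulerStats and a matching Prometheus gauge that is set through _log_gauge. A rough usage sketch with prometheus_client, assuming a version that accepts the "mostrecent" multiprocess mode; the labels dict and the sample value are illustrative and not taken from the diff:

from prometheus_client import Gauge

labels = {"model_name": "example-model"}  # illustrative label set

spec_accept_length = Gauge(
    name="sglang:spec_accept_length",
    documentation="The average acceptance length of speculative decoding.",
    labelnames=labels.keys(),
    multiprocess_mode="mostrecent",
)

# Roughly what SchedulerMetricsCollector._log_gauge(gauge, value) does:
spec_accept_length.labels(**labels).set(2.3)  # illustrative value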
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable

 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm


@@ -122,6 +119,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size

         # Batch sizes to capture
         self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +216,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.
+                    self.max_bs * self.dp_size,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
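The CudaGraphRunner change above sizes the data-parallel attention gather buffer by dp_size rather than the tensor-parallel size. A tiny sketch of the resulting allocation shape (the numeric values and dtype are illustrative, not from the diff):

import torch

max_bs, dp_size, hidden_size = 160, 2, 4096  # illustrative values
gathered_buffer = torch.zeros((max_bs * dp_size, hidden_size), dtype=torch.bfloat16)
print(gathered_buffer.shape)  # torch.Size([320, 4096])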
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple

 import torch
 import torch.distributed as dist
-
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_size,
+    initialize_dp_attention,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -57,8 +63,8 @@ from sglang.srt.utils import (
     init_custom_process_group,
     is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )

@@ -101,8 +107,10 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-
-            self.server_args.
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"

         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -159,6 +167,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )

@@ -216,9 +225,12 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"

         if not self.server_args.enable_p2p_check:
-
+            monkey_patch_p2p_access_check()
+
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
@@ -226,7 +238,7 @@ class ModelRunner:
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

         if not self.is_draft_worker:
-            # Only
+            # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
                 world_size=self.tp_size,
@@ -235,11 +247,18 @@ class ModelRunner:
                 distributed_init_method=dist_init_method,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_dp_attention(
+                enable_dp_attention=self.server_args.enable_dp_attention,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+            )

         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
@@ -257,7 +276,8 @@ class ModelRunner:
         )

         # This can reduce thread conflicts and speed up weight loading.
-
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
                 logger.info(
@@ -277,12 +297,15 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()

         # Load the model
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
                 device_config=DeviceConfig(self.device),
             )
+        monkey_patch_vllm_parallel_state(reverse=True)

         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -521,7 +544,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -595,7 +618,6 @@ class ModelRunner:
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
-            use_records=False,
             enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
@@ -615,7 +637,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -626,7 +648,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
sglang/srt/model_loader/loader.py
CHANGED
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank

         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank

         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 import filelock
 import gguf
@@ -19,10 +29,10 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank

 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once

@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:

     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
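The new kv_cache_scales_loader above reads a JSON file validated by QuantParamSchema (not shown in this diff) whose kv_cache.scaling_factor entry maps TP ranks to per-layer scaling factors. A small illustrative sketch of writing and reading a file with that shape using plain json; the exact field names are assumed from the attribute accesses in the function and may differ from the real schema:

import json

# Hypothetical content: TP rank -> {layer index -> KV cache scaling factor}
payload = {
    "kv_cache": {
        "scaling_factor": {
            "0": {"0": 1.0, "1": 0.875},
        }
    }
}

with open("kv_cache_scales.json", "w") as f:
    json.dump(payload, f)

with open("kv_cache_scales.json") as f:
    data = json.load(f)

# Per-layer (layer, scale) pairs for TP rank 0, analogous to
# layer_scales_map.items() in kv_cache_scales_loader.
print(list(data["kv_cache"]["scaling_factor"]["0"].items()))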
sglang/srt/models/baichuan.py
CHANGED
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/chatglm.py
CHANGED
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/commandr.py
CHANGED
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
sglang/srt/models/dbrx.py
CHANGED
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
-
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
sglang/srt/models/deepseek.py
CHANGED
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -23,14 +23,13 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb =
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )

         if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-
-
-
-        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [