sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -31,6 +31,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE


+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2

     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST

+    def __lt__(self, other):
+        return self.value < other.value
+

 @dataclass
 class ForwardBatch:
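The `@total_ordering` decorator and the new `__lt__` make the ordering NULL < LAST < FULL explicit, so the strongest capture mode requested across batches can be picked with `max()`. A minimal, self-contained sketch of the same pattern (the `merged_capture_mode` helper is illustrative and not part of sglang):

```python
from enum import IntEnum
from functools import total_ordering


@total_ordering
class CaptureHiddenMode(IntEnum):
    NULL = 0  # capture nothing
    LAST = 1  # capture the last token's hidden state
    FULL = 2  # capture hidden states of all tokens

    def __lt__(self, other):
        return self.value < other.value


def merged_capture_mode(modes):
    # Illustrative helper: pick the most demanding mode requested by any batch.
    return max(modes, default=CaptureHiddenMode.NULL)


assert CaptureHiddenMode.NULL < CaptureHiddenMode.LAST <= CaptureHiddenMode.FULL
assert merged_capture_mode([CaptureHiddenMode.LAST, CaptureHiddenMode.FULL]) is CaptureHiddenMode.FULL
```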
@@ -219,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None

+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None

@@ -295,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -351,8 +360,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
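The new `token_type_ids` field threads BERT-style segment IDs through `ForwardBatch` so cross-encoder models can tell the two packed segments apart. A small illustrative sketch of how such segment IDs are typically built for a query/document pair (the helper below is hypothetical, not sglang API):

```python
import torch


def build_token_type_ids(query_len: int, doc_len: int) -> torch.Tensor:
    # Segment 0 covers the query tokens, segment 1 covers the document tokens,
    # as in standard BERT cross-encoder inputs.
    return torch.cat(
        [
            torch.zeros(query_len, dtype=torch.long),
            torch.ones(doc_len, dtype=torch.long),
        ]
    )


print(build_token_type_ids(4, 3))  # tensor([0, 0, 0, 0, 1, 1, 1])
```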
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import …
- … (2 removed lines not captured)
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()

         # Update deep gemm configure
-        if …
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
sglang/srt/model_loader/loader.py
CHANGED
@@ -1259,12 +1259,19 @@ class GGUFModelLoader(BaseModelLoader):
         ):
             model_config.hf_config.update({"tie_word_embeddings": True})

+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map)
             )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model


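The loop added to the GGUF loader lets each quantized module finalize its weights once loading is done. A minimal sketch of the same post-load hook pattern, using stand-in classes instead of sglang's real loader and quantization objects:

```python
import torch
import torch.nn as nn


class ToyQuantMethod:
    """Stand-in for a quantization method with a post-load hook."""

    def process_weights_after_loading(self, module: nn.Module) -> None:
        # e.g. repack / cast the freshly loaded weight in place
        module.weight.data = module.weight.data.to(torch.float16)


def finalize_quantized_modules(model: nn.Module) -> None:
    # Mirror of the pattern in the diff: visit every submodule and let its
    # quant_method (if any) post-process the loaded weights.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)


model = nn.Linear(4, 4)
model.quant_method = ToyQuantMethod()
finalize_quantized_modules(model)
assert model.weight.dtype == torch.float16
```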
sglang/srt/models/bert.py
CHANGED
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 BertConfig = None

@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        … (1 removed line not captured)
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()

@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)

         # Position embeddings.
-        position_embeddings = self.position_embeddings( …
+        position_embeddings = self.position_embeddings(positions)

-        token_type_ids = …
-        … (2 removed lines not captured)
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )

         token_type_embeddings = self.token_type_embeddings(token_type_ids)

@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings


+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # simply taking the hidden state corresponding
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+
+        return pooled_output
+
+
 class BertEncoder(nn.Module):

     def __init__(
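`BertPooler` is the classic BERT pooling head: the hidden state of the first token is passed through a dense layer and a tanh. A self-contained sketch of that computation on dummy data (shapes are illustrative; in sglang the leading dimension is the packed token dimension):

```python
import torch
import torch.nn as nn


class FirstTokenPooler(nn.Module):
    """Dense + tanh over the first token's hidden state, as in BERT."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        first_token = hidden_states[0, :]  # hidden state of token 0
        return self.activation(self.dense(first_token))


hidden = torch.randn(7, 16)               # 7 packed tokens, hidden size 16
pooled = FirstTokenPooler(16)(hidden)
print(pooled.shape)                        # torch.Size([16])
```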
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()

+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output


@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)

     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):

         hidden_states = self.embeddings(
             input_ids=input_ids,
-            … (1 removed line not captured)
+            positions=positions,
+            forward_batch=forward_batch,
         )

         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        … (1 removed line not captured)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass


-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
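`BertForSequenceClassification.load_weights` splits the checkpoint by name prefix: entries under `bert.` are streamed into the inner `BertModel` with the prefix stripped, while the remaining entries (the classifier head) are loaded directly. A standalone sketch of the same prefix-splitting idea with plain lists instead of sglang's weight loaders:

```python
from typing import Iterable, List, Tuple

import torch


def split_by_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]], prefix: str = "bert."
) -> Tuple[List[Tuple[str, torch.Tensor]], List[Tuple[str, torch.Tensor]]]:
    """Return (inner weights with the prefix stripped, remaining weights)."""
    inner, rest = [], []
    for name, tensor in weights:
        if name.startswith(prefix):
            inner.append((name[len(prefix):], tensor))
        else:
            rest.append((name, tensor))
    return inner, rest


checkpoint = [
    ("bert.embeddings.word_embeddings.weight", torch.zeros(2, 2)),
    ("classifier.weight", torch.zeros(2, 2)),
]
inner, rest = split_by_prefix(checkpoint)
print([n for n, _ in inner])  # ['embeddings.word_embeddings.weight']
print([n for n, _ in rest])   # ['classifier.weight']
```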
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -51,11 +51,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -109,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
 else:
     from vllm._custom_ops import awq_dequantize

@@ -980,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1013,7 +1010,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if …
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -1032,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
@@ -1708,53 +1709,35 @@ class DeepseekV2ForCausalLM(nn.Module):
     def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.num_fused_shared_experts = …
-        … (28 removed lines not captured)
-        if (
-            _is_cuda
-            and torch.cuda.get_device_capability("cuda") >= (9, 0)
-            and self.config.architectures[0] == architecture
-            and self.config.n_routed_experts == 256
-            and (
-                not (
-                    global_server_args_dict["enable_deepep_moe"]
-                    or global_server_args_dict["enable_ep_moe"]
-                )
-            )
-        ):
-            self.num_fused_shared_experts = self.config.n_shared_experts
-            global_server_args_dict["disable_shared_experts_fusion"] = False
-            log_info_on_rank0(
-                logger,
-                "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
-            )
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (9, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_routed_experts != 256
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+        elif (
+            global_server_args_dict["enable_deepep_moe"]
+            or global_server_args_dict["enable_ep_moe"]
+        ):
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts

     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -1786,8 +1769,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         for name in weight_names:
             if "kv_b_proj" in name:
                 layer_id = int(name.split(".")[2])
-                … (1 removed line not captured)
-                if layer_id != self.config.num_hidden_layers:
+                if layer_id < self.config.num_hidden_layers:
                     layer_ids.add(layer_id)

         for layer_id in layer_ids:
@@ -1847,8 +1829,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                if …
-                … (1 removed line not captured)
+                if (
+                    deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                    and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+                    and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                 ):
                     block_scale = weight_scale
                     use_deep_gemm_bmm = True
@@ -1932,6 +1916,65 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
             self_attn.use_deep_gemm_bmm = True

+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+        ):
+            self._weight_requant_ue8m0()
+
+    def _weight_requant_ue8m0(self):
+        weight_block_size = self.quant_config.weight_block_size
+
+        moe_layers = list(
+            range(
+                self.config.first_k_dense_replace,
+                self.config.num_hidden_layers,
+                self.config.moe_layer_freq,
+            )
+        )
+
+        for layer_id in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_id]
+
+            for module in [
+                layer.self_attn.fused_qkv_a_proj_with_mqa,
+                layer.self_attn.q_b_proj,
+                layer.self_attn.kv_b_proj,
+                layer.self_attn.o_proj,
+            ]:
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
+            if layer_id in moe_layers:
+                shared_experts = getattr(layer.mlp, "shared_experts", None)
+                if shared_experts is not None:
+                    for module in [
+                        shared_experts.gate_up_proj,
+                        shared_experts.down_proj,
+                    ]:
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
+                experts = layer.mlp.experts
+                if isinstance(experts, DeepEPMoE):
+                    for w in [
+                        experts.w13_weight_fp8,
+                        experts.w2_weight_fp8,
+                    ]:
+                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+            else:
+                mlp = layer.mlp
+                assert isinstance(mlp, DeepseekV2MLP)
+                for module in [
+                    mlp.gate_up_proj,
+                    mlp.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):

         if is_nextn:
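`_weight_requant_ue8m0` treats a layer as MoE when its index lies in `range(first_k_dense_replace, num_hidden_layers, moe_layer_freq)`: the first `first_k_dense_replace` layers stay dense and every `moe_layer_freq`-th layer after that is MoE. A small sketch of that index computation with made-up config values:

```python
def moe_layer_ids(first_k_dense_replace: int, num_hidden_layers: int, moe_layer_freq: int):
    # Same range expression as in the diff above.
    return list(range(first_k_dense_replace, num_hidden_layers, moe_layer_freq))


# Hypothetical config: 8 layers, the first 3 dense, every layer after that MoE.
print(moe_layer_ids(3, 8, 1))  # [3, 4, 5, 6, 7]
# Hypothetical config: MoE on every second layer after the first dense one.
print(moe_layer_ids(1, 8, 2))  # [1, 3, 5, 7]
```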
@@ -1952,101 +1995,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale_inv",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale_inv",
-                        "up_proj.weight",
-                        "up_proj.weight_scale_inv",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.num_fused_shared_experts} "
-                "shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -2072,9 +2020,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]

+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            logger.info("Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             weight_names.append(name)

         if not is_nextn:
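With the per-layer cloning loop removed, shared-expert fusion is now just a rename at weight-load time: the shared expert's tensors are redirected to an extra expert slot with index `n_routed_experts`. A standalone sketch of that substitution with made-up names:

```python
def fuse_shared_expert_name(name: str, n_routed_experts: int) -> str:
    # Same substitution as in the diff: the shared expert becomes expert
    # number `n_routed_experts` inside the fused MoE weights.
    return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")


name = "model.layers.5.mlp.shared_experts.gate_proj.weight"
print(fuse_shared_expert_name(name, 256))
# model.layers.5.mlp.experts.256.gate_proj.weight
```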