sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +4 -1
- sglang/srt/layers/rotary_embedding.py +6 -1
- sglang/srt/layers/sampler.py +28 -8
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +61 -35
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -65
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,22 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
-    tp_group:
+    num_tokens: int,
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model, reverse=False,
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True,
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm
 
 
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
 
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )
 
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
@@ -189,14 +187,14 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
             # Speculative_inference
             if model_runner.spec_algorithm.is_eagle():
@@ -285,8 +283,8 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,
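The `patch_model` context manager above now takes a required `num_tokens` argument rather than a bare batch size, and `tp_group` is typed as `GroupCoordinator`. A minimal caller sketch, not part of the wheel — it assumes an already initialized sglang `ModelRunner` exposing `.model` and `.tp_group`, as `CudaGraphRunner` uses above:

```python
# Illustrative only: mirrors how CudaGraphRunner invokes the 0.4.2 patch_model API.
from sglang.srt.model_executor.cuda_graph_runner import patch_model


def capture_under_compile(model_runner, run_capture, bs: int, num_tokens_per_bs: int = 1):
    # num_tokens is the token count the captured graph will run:
    # bs * num_tokens_per_bs (the two only differ when EAGLE draft tokens are used).
    with patch_model(
        model_runner.model,
        enable_compile=True,
        num_tokens=bs * num_tokens_per_bs,
        tp_group=model_runner.tp_group,
    ) as forward:
        # `forward` is the (possibly torch.compile-wrapped) model.forward;
        # graph capture must happen inside this block, before the patch is reverted.
        return run_capture(forward)
```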
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)
 
-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-@
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()
 
         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
 
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -215,7 +218,7 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
sglang/srt/model_loader/loader.py
CHANGED
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()
 
 
+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""
 
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
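For context on the `LoadFormat.LAYERED` branch added above: `LayeredModelLoader` resets `load_format` back to `AUTO` for the actual weight reading and requires the model to implement `load_weights_to_module`. A small selection sketch, not part of the wheel — it assumes `LoadConfig`/`LoadFormat` are importable from `sglang.srt.configs.load_config` (that module gains one line in this release) and that `LoadConfig` accepts `load_format` directly; only `LoadFormat.LAYERED` and `LayeredModelLoader` themselves are confirmed by this diff:

```python
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.loader import get_model_loader

# get_model_loader dispatches on load_format; LAYERED now maps to LayeredModelLoader.
loader = get_model_loader(LoadConfig(load_format=LoadFormat.LAYERED))
print(type(loader).__name__)  # LayeredModelLoader
```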
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -27,6 +27,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
 
@@ -403,8 +404,13 @@ def np_cache_weights_iterator(
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-    """Iterate over the weights in the model safetensor files."""
+    """Iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -414,9 +420,14 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
+        if not is_all_weights_sharded:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+        else:
+            result = load_file(st_file, device="cpu")
+            for name, param in result.items():
                 yield name, param
 
 
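A usage sketch for the new `is_all_weights_sharded` flag, not part of the wheel; the file path is a placeholder. When the flag is set, each `.safetensors` shard is read in a single `load_file` call instead of tensor by tensor through `safe_open`:

```python
from sglang.srt.model_loader.weight_utils import safetensors_weights_iterator

shard_files = ["/path/to/model-00001-of-00002.safetensors"]  # placeholder path

for name, tensor in safetensors_weights_iterator(
    shard_files, is_all_weights_sharded=True
):
    print(name, tuple(tensor.shape))
```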
@@ -650,6 +661,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     return name
 
 
+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
 def kv_cache_scales_loader(
     filename: str,
     tp_rank: int,
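A validation sketch for the vendored pydantic schemas above, not part of the wheel; the payload and context values are made up. Per the validators, `tp_size`, `num_hidden_layers`, `tp_rank`, and optionally `model_type` are read from the validation context:

```python
from sglang.srt.model_loader.weight_utils import QuantParamSchema

payload = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> layer index -> per-tensor KV cache scale
        "scaling_factor": {0: {0: 1.0, 1: 1.0}},
    },
}

schema = QuantParamSchema.model_validate(
    payload,
    context={"model_type": "llama", "tp_size": 1, "num_hidden_layers": 2, "tp_rank": 0},
)
print(schema.kv_cache.dtype)  # float8_e4m3fn
```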
@@ -681,7 +767,7 @@ def kv_cache_scales_loader(
     except json.JSONDecodeError:
         logger.error("Error decoding JSON in file '%s'.", filename)
     except Exception:
-        logger.
+        logger.error("An error occurred while reading '%s'.", filename)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
sglang/srt/models/commandr.py
CHANGED
@@ -61,7 +61,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
@@ -372,10 +375,19 @@ class CohereForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
 
-EntryClass = CohereForCausalLM
+class Cohere2ForCausalLM(CohereForCausalLM):
+    pass
+
+
+EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
sglang/srt/models/dbrx.py
CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, weight_name)
                 break
             else:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
-if
-    from
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
sglang/srt/models/gemma2.py
CHANGED
@@ -35,7 +35,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import make_layers
 
 
@@ -424,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/grok.py
CHANGED
sglang/srt/models/minicpm3.py
CHANGED
@@ -40,10 +40,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda_available
 
-if
-    from
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights_to_module(
+        self,
+        fqn: str,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto submodule pointed by path `fqn`."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        module = self.get_submodule(fqn)
+        params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto the full model."""
+        self.load_weights_to_module("", weights)
+
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
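A short sketch of the contract this split establishes, not part of the wheel: `model` is assumed to be a `TorchNativeLlamaForCausalLM` instance and `weights` the usual iterable of `(name, tensor)` pairs from the loader; the `"lm_head"` module path is illustrative. Because the method filters with `named_parameters(prefix=fqn, recurse=False)`, only parameters owned directly by the addressed submodule are filled, which is what lets `LayeredModelLoader` fill and quantize one module at a time:

```python
def load_one_submodule(model, weights, fqn: str = "lm_head") -> None:
    # Fill only the parameters registered directly on the submodule at `fqn`;
    # weight names that do not resolve under that prefix are skipped.
    model.load_weights_to_module(fqn, weights)
```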