sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
     # Used in speculative decoding: extend a batch in the draft model.
     DRAFT_EXTEND = auto()

-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
-
     # Split Prefill for PD multiplexing
     SPLIT_PREFILL = auto()

@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
     def is_cpu_graph(self):
         return self == ForwardMode.DECODE

-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
-
     def is_split_prefill(self):
         return self == ForwardMode.SPLIT_PREFILL

@@ -285,6 +278,9 @@ class ForwardBatch:
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # Speculative decoding
     spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
@@ -332,6 +328,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
+            is_prefill_only=batch.is_prefill_only,
             lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
@@ -902,17 +899,6 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


-@dataclass
-class ForwardBatchOutput:
-    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
-    # need to be more organized
-    logits_output: Optional[torch.Tensor] = None
-    next_token_ids: Optional[torch.Tensor] = None
-    num_accepted_tokens: Optional[int] = None
-    pp_proxy_tensors: Optional[PPProxyTensors] = None
-    can_run_cuda_graph: bool = False
-
-
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1

sglang/srt/model_executor/model_runner.py
CHANGED
@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        if self.
-
+        if config := self.mambaish_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
                     )
                 else:
                     self.server_args.max_mamba_cache_size = 512
+            if self.hybrid_gdn_config is not None:
                 self.server_args.max_mamba_cache_size = (
                     self.server_args.max_mamba_cache_size
                     // (
@@ -880,7 +883,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)

         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -1267,8 +1270,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.
-            num_layers = len(
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1277,6 +1280,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
@@ -1288,22 +1302,32 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.
+        if config := self.mambaish_config:
             rest_memory -= (
                 self.server_args.max_mamba_cache_size
-                *
+                * config.mamba2_cache_params.mamba_cache_per_req
                 / (1 << 30)
             )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     @property
-    def
-
-
-
-
-
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config

     def set_num_token_hybrid(self):
         if (
@@ -1438,7 +1462,7 @@ class ModelRunner:
             ),
             4096,
         )
-        if self.
+        if self.mambaish_config is not None:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1543,14 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
-        elif self.
-            config = self.model_config.hf_config
-            (
-                conv_state_shape,
-                temporal_state_shape,
-                conv_dtype,
-                ssm_dtype,
-                mamba_layers,
-            ) = config.hybrid_gdn_params
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
@@ -1640,7 +1652,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,9 +1663,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
@@ -1672,13 +1682,17 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_kv_cache_copy=(
+                    self.server_args.speculative_algorithm is not None
+                ),
             )

         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend"
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
@@ -1743,16 +1757,10 @@ class ModelRunner:

     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
@@ -2057,15 +2065,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        #
-
-
-
-
-            sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
sglang/srt/model_loader/__init__.py
CHANGED
@@ -24,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
         device_config=device_config,
sglang/srt/model_loader/loader.py
CHANGED
@@ -37,10 +37,22 @@ import numpy as np
 import requests
 import safetensors.torch
 import torch
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     trigger_transferring_weights_request,
 )
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
 from sglang.srt.model_loader.weight_utils import (
     _BAR_FORMAT,
     default_weight_loader,
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig

 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices


 @contextmanager
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )

+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+
+        quant_choice_str = model_config.modelopt_quant
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )

-
-
-
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )

         return model.eval()

@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
     return model.eval()


-
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Use shared method from parent class to load base model
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
+        try:
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.utils.dataset_utils import create_forward_loop
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use 'modelopt_quant' feature."
+            )
+            raise
+
+        quant_choice_str = model_config.modelopt_quant
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
+                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
+                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
+                "attribute names of config objects in modelopt.torch.quantization."
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
+                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
+                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
+            )
+
+        # For now, assume no calibration. Calibration setup is a separate, more complex step.
+        use_calibration = False  # This would ideally be a configurable parameter
+        calib_dataloader = None  # This would need to be provided/configured
+
+        calibrate_loop = (
+            create_forward_loop(dataloader=calib_dataloader)
+            if use_calibration
+            else None
+        )
+
+        if use_calibration and calib_dataloader is None:
+            logger.warning(
+                "ModelOpt calibration requested but no calib_dataloader provided. "
+                "Proceeding without calibration. Quantization accuracy may be affected."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
+        )
+
+        try:
+            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+            logger.info("Model successfully quantized with ModelOpt.")
+        except Exception as e:
+            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
+            raise
+        mtq.print_quant_summary(model)
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""

+    if (
+        model_config
+        and hasattr(model_config, "modelopt_quant")
+        and model_config.modelopt_quant
+    ):
+        logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)

sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -226,6 +226,9 @@ def get_quant_config(
             return ModelOptFp4Config.from_config(config)
         else:
             return quant_cls.from_config(config)
+    elif model_config.quantization == "modelopt_fp8":
+        if config["producer"]["name"] == "modelopt_fp8":
+            return quant_cls.from_config(config)
     else:
         raise ValueError(
             f"Unsupported quantization config"
sglang/srt/models/falcon_h1.py
CHANGED
@@ -8,6 +8,10 @@ from torch import nn
 from sglang.srt.configs.falcon_h1 import FalconH1Config
 from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
 from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
         )

         self.mamba = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
             hidden_size=config.hidden_size,
-            ssm_state_size=config.mamba_d_state,
-            conv_kernel_size=config.mamba_d_conv,
-            intermediate_size=self.d_ssm,
             use_conv_bias=config.mamba_conv_bias,
             use_bias=config.mamba_proj_bias,
             n_groups=config.mamba_n_groups,
-            num_heads=config.mamba_n_heads,
-            layer_id=layer_id,
-            head_dim=config.mamba_d_head,
             rms_norm_eps=config.rms_norm_eps,
-            chunk_size=config.mamba_chunk_size,
             activation=config.hidden_act,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
         )
         attention_hidden_states = attention_hidden_states * self.attn_out_multiplier

+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
         # Mamba block
         mamba_hidden_states = torch.empty_like(hidden_states)
-
+        attn_backend.linear_attn_backend.forward(
+            self.mamba,
             hidden_states * self.ssm_in_multiplier,
             mamba_hidden_states,
-
+            layer_id=self.layer_id,
             mup_vector=self.mup_vector,
         )
         mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )

     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config

+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):

         return hs

+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()

sglang/srt/models/grok.py
CHANGED
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
             custom_routing_function=custom_routing_function,
         )

-
-        if get_moe_expert_parallel_world_size() > 1:
-            MoEImpl = EPMoE
-        else:
-            MoEImpl = FusedMoE
-        kwargs["reduce_results"] = reduce_results
-        kwargs["use_presharded_weights"] = use_presharded_weights
-        kwargs["inplace"] = inplace
-        kwargs["no_combine"] = no_combine
-
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
             params_dtype=params_dtype,
             quant_config=quant_config,
             activation="gelu",
-
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+            inplace=inplace,
+            no_combine=no_combine,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: