sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
     # Used in speculative decoding: extend a batch in the draft model.
     DRAFT_EXTEND = auto()
 
-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
-
     # Split Prefill for PD multiplexing
     SPLIT_PREFILL = auto()
 
@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
     def is_cpu_graph(self):
         return self == ForwardMode.DECODE
 
-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
-
     def is_split_prefill(self):
         return self == ForwardMode.SPLIT_PREFILL
 
@@ -285,6 +278,9 @@ class ForwardBatch:
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # Speculative decoding
     spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
@@ -332,6 +328,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
+            is_prefill_only=batch.is_prefill_only,
             lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
@@ -902,17 +899,6 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None
 
 
-@dataclass
-class ForwardBatchOutput:
-    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
-    # need to be more organized
-    logits_output: Optional[torch.Tensor] = None
-    next_token_ids: Optional[torch.Tensor] = None
-    num_accepted_tokens: Optional[int] = None
-    pp_proxy_tensors: Optional[PPProxyTensors] = None
-    can_run_cuda_graph: bool = False
-
-
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
 
sglang/srt/model_executor/model_runner.py CHANGED
@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
-        if self.is_hybrid_gdn:
-            logger.warning("Hybrid GDN model detected, disable radix cache")
+        if config := self.mambaish_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
                     )
                 else:
                     self.server_args.max_mamba_cache_size = 512
+            if self.hybrid_gdn_config is not None:
                 self.server_args.max_mamba_cache_size = (
                     self.server_args.max_mamba_cache_size
                     // (
@@ -880,7 +883,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)
 
         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -1267,8 +1270,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.is_hybrid_gdn:
-            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
@@ -1277,6 +1280,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
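For a sense of the added NSA overhead: per token and per full-attention layer, the indexer stores index_head_dim index elements plus 4 extra elements of scale for every quant_block_size block, all in the index buffer's dtype. A back-of-the-envelope check with assumed values (none of the numbers below come from the diff or a real config):

# Illustrative only; every value here is an assumption, not read from a model config.
index_head_dim = 128      # assumed NSA index head dimension
quant_block_size = 128    # assumed NSATokenToKVPool.quant_block_size
element_size = 1          # assumed 1-byte index_k_with_scale_buffer dtype
num_layers = 61           # assumed number of full-attention layers

indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4
print(indexer_size_per_token)                              # 132 elements per layer
print(indexer_size_per_token * num_layers * element_size)  # 8052 extra bytes per token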
@@ -1288,22 +1302,32 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.is_hybrid_gdn:
+        if config := self.mambaish_config:
             rest_memory -= (
                 self.server_args.max_mamba_cache_size
-                * self.model_config.hf_config.mamba_cache_per_req
+                * config.mamba2_cache_params.mamba_cache_per_req
                 / (1 << 30)
             )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
     @property
-    def is_hybrid_gdn(self):
-        return self.model_config.hf_config.architectures[0] in [
-            "Qwen3NextForCausalLM",
-            "Qwen3NextForCausalLMMTP",
-            "FalconH1ForCausalLM",
-        ]
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config
 
     def set_num_token_hybrid(self):
         if (
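The string match on hf_config.architectures[0] is replaced by isinstance checks against typed config classes, so every mamba-style family flows through one mambaish_config path. A condensed, standalone sketch of the same dispatch (simplified; the real properties live on ModelRunner, and the config classes come from the import added above):

from typing import Optional

from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig


def mambaish_config(hf_config) -> Optional[object]:
    """Return the typed config if the model needs mamba-style caches, else None."""
    if isinstance(hf_config, Qwen3NextConfig):                    # hybrid GDN family
        return hf_config
    if isinstance(hf_config, (FalconH1Config, NemotronHConfig)):  # mamba2 family
        return hf_config
    return None


# Call sites then use the walrus pattern from the hunks above, e.g.:
# if config := mambaish_config(self.model_config.hf_config):
#     reserve = max_mamba_cache_size * config.mamba2_cache_params.mamba_cache_per_req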
@@ -1438,7 +1462,7 @@ class ModelRunner:
                 ),
                 4096,
             )
-        if self.is_hybrid_gdn:
+        if self.mambaish_config is not None:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1543,14 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
-        elif self.is_hybrid_gdn:
-            config = self.model_config.hf_config
-            (
-                conv_state_shape,
-                temporal_state_shape,
-                conv_dtype,
-                ssm_dtype,
-                mamba_layers,
-            ) = config.hybrid_gdn_params
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-                conv_state_shape=conv_state_shape,
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
@@ -1640,7 +1652,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.is_hybrid_gdn:
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,9 +1663,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
@@ -1672,13 +1682,17 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_kv_cache_copy=(
+                    self.server_args.speculative_algorithm is not None
+                ),
             )
 
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
@@ -1743,16 +1757,10 @@ class ModelRunner:
 
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
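The per-phase backend fallback is now centralized in a single ServerArgs helper instead of being duplicated inline. The hunk only shows the call site; a minimal sketch of what get_attention_backends() presumably does, inferred from the removed inline logic (the body below is an assumption, not copied from server_args.py):

# Sketch only: assumed shape of ServerArgs.get_attention_backends(),
# mirroring the inline fallback removed in this hunk.
def get_attention_backends(self):
    prefill = self.prefill_attention_backend or self.attention_backend
    decode = self.decode_attention_backend or self.attention_backend
    return prefill, decode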
@@ -2057,15 +2065,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
     def sample(
sglang/srt/model_loader/__init__.py CHANGED
@@ -24,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
         device_config=device_config,
sglang/srt/model_loader/loader.py CHANGED
@@ -37,10 +37,22 @@ import numpy as np
 import requests
 import safetensors.torch
 import torch
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     trigger_transferring_weights_request,
 )
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
 from sglang.srt.model_loader.weight_utils import (
     _BAR_FORMAT,
     default_weight_loader,
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+
+        quant_choice_str = model_config.modelopt_quant
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-                self.load_weights_and_postprocess(
-                    model, self._get_all_weights(model_config, model), target_device
-                )
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Use shared method from parent class to load base model
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
+        try:
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.utils.dataset_utils import create_forward_loop
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use 'modelopt_quant' feature."
+            )
+            raise
+
+        quant_choice_str = model_config.modelopt_quant
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
+                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
+                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
+                "attribute names of config objects in modelopt.torch.quantization."
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
+                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
+                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
+            )
+
+        # For now, assume no calibration. Calibration setup is a separate, more complex step.
+        use_calibration = False  # This would ideally be a configurable parameter
+        calib_dataloader = None  # This would need to be provided/configured
+
+        calibrate_loop = (
+            create_forward_loop(dataloader=calib_dataloader)
+            if use_calibration
+            else None
+        )
+
+        if use_calibration and calib_dataloader is None:
+            logger.warning(
+                "ModelOpt calibration requested but no calib_dataloader provided. "
+                "Proceeding without calibration. Quantization accuracy may be affected."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
+        )
+
+        try:
+            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+            logger.info("Model successfully quantized with ModelOpt.")
+        except Exception as e:
+            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
+            raise
+        mtq.print_quant_summary(model)
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if (
+        model_config
+        and hasattr(model_config, "modelopt_quant")
+        and model_config.modelopt_quant
+    ):
+        logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
 
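Loader selection is now driven by the model config rather than the load format alone. A hypothetical usage sketch (the model name is a placeholder, and setting modelopt_quant directly on the instance is an assumption made for illustration; in practice the value comes from the model/server configuration):

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

model_config = ModelConfig("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
model_config.modelopt_quant = "fp8"   # preset key looked up in QUANT_CFG_CHOICES
loader = get_model_loader(LoadConfig(), model_config)
print(type(loader).__name__)          # ModelOptModelLoader; DefaultModelLoader otherwise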
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -226,6 +226,9 @@ def get_quant_config(
             return ModelOptFp4Config.from_config(config)
         else:
             return quant_cls.from_config(config)
+    elif model_config.quantization == "modelopt_fp8":
+        if config["producer"]["name"] == "modelopt_fp8":
+            return quant_cls.from_config(config)
     else:
         raise ValueError(
             f"Unsupported quantization config"
sglang/srt/models/falcon_h1.py CHANGED
@@ -8,6 +8,10 @@ from torch import nn
 from sglang.srt.configs.falcon_h1 import FalconH1Config
 from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
 from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
         )
 
         self.mamba = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
             hidden_size=config.hidden_size,
-            ssm_state_size=config.mamba_d_state,
-            conv_kernel_size=config.mamba_d_conv,
-            intermediate_size=self.d_ssm,
             use_conv_bias=config.mamba_conv_bias,
             use_bias=config.mamba_proj_bias,
             n_groups=config.mamba_n_groups,
-            num_heads=config.mamba_n_heads,
-            layer_id=layer_id,
-            head_dim=config.mamba_d_head,
             rms_norm_eps=config.rms_norm_eps,
-            chunk_size=config.mamba_chunk_size,
             activation=config.hidden_act,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
         )
         attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
 
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
         # Mamba block
         mamba_hidden_states = torch.empty_like(hidden_states)
-        self.mamba(
+        attn_backend.linear_attn_backend.forward(
+            self.mamba,
             hidden_states * self.ssm_in_multiplier,
             mamba_hidden_states,
-            forward_batch=forward_batch,
+            layer_id=self.layer_id,
             mup_vector=self.mup_vector,
         )
         mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
sglang/srt/models/gemma3_mm.py CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
 
 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
 
     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config
 
+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
         return hs
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()
 
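A quick check of the new lora_pattern: only language-model attention/MLP projection modules match, so should_apply_lora returns False for the vision tower and projector. The module names below are illustrative:

import re

lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

print(bool(lora_pattern.match("language_model.model.layers.3.self_attn.qkv_proj")))  # True
print(bool(lora_pattern.match("language_model.model.layers.3.mlp.gate_up_proj")))    # True
print(bool(lora_pattern.match("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")))  # False
print(bool(lora_pattern.match("multi_modal_projector.mm_input_projection")))         # False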
sglang/srt/models/grok.py CHANGED
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
             custom_routing_function=custom_routing_function,
         )
 
-        kwargs = {}
-        if get_moe_expert_parallel_world_size() > 1:
-            MoEImpl = EPMoE
-        else:
-            MoEImpl = FusedMoE
-            kwargs["reduce_results"] = reduce_results
-            kwargs["use_presharded_weights"] = use_presharded_weights
-            kwargs["inplace"] = inplace
-            kwargs["no_combine"] = no_combine
-
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
             params_dtype=params_dtype,
             quant_config=quant_config,
             activation="gelu",
-            **kwargs,
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+            inplace=inplace,
+            no_combine=no_combine,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: