sglang-0.5.3-py3-none-any.whl → sglang-0.5.3.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/falcon_h1.py
@@ -15,16 +15,12 @@
  """Falcon-H1 model configuration"""

  import enum
- import os

- import numpy as np
- import torch
  from transformers.configuration_utils import PretrainedConfig
  from transformers.modeling_rope_utils import rope_config_validation
  from transformers.utils import logging

- from sglang.srt.distributed.utils import divide
- from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
  from sglang.srt.layers.dp_attention import (
      get_attention_tp_size,
      get_tensor_model_parallel_world_size,
@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
          self.rope_scaling = None
          self.rope_scaling = rope_scaling
          self.projectors_bias = projectors_bias
-         mamba_intermediate = (
+         self.mamba_intermediate = mamba_intermediate = (
              mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
          )

@@ -294,18 +290,6 @@
      def layers_block_type(self):
          return ["falcon_h1" for i in range(self.num_hidden_layers)]

-     @property
-     def mamba_cache_per_req(self):
-         conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
-             self.hybrid_gdn_params
-         )
-         mamba_layers_len = len(mamba_layers)
-
-         return (
-             int(np.prod(conv_state_shape)) * conv_dtype.itemsize
-             + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
-         ) * mamba_layers_len
-
      @property
      def full_attention_layer_ids(self):
          # For Falcon-H1, we do have attention on all layers
@@ -317,44 +301,14 @@
          return range(self.num_hidden_layers)

      @property
-     def hybrid_gdn_params(self):
-         world_size = get_tensor_model_parallel_world_size()
-
-         n_groups = self.mamba_n_groups
-         if self.mamba_n_groups % world_size != 0:
-             # - for TP we shard conv_dim by sharding on n_groups,
-             # - but if n_groups cannot divide tp_size, we need to
-             # extend some extra groups
-             extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
-                 self.mamba_n_groups, world_size
-             )
-             n_groups += extra_groups
-
-         conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
-
-         conv_state_shape = (
-             divide(conv_dim, world_size),
-             self.mamba_d_conv - 1,
-         )
-
-         # we TP-ize on the heads dimension
-         temporal_state_shape = (
-             self.mamba_d_state,
-             self.mamba_d_head,
-             divide(self.mamba_n_heads, world_size),
-         )
-         conv_dtype = torch.bfloat16
-         dtype_map = {
-             "float32": torch.float32,
-             "bfloat16": torch.bfloat16,
-         }
-         ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
-         mamba_layers = self.linear_layer_ids
-
-         return (
-             conv_state_shape,
-             temporal_state_shape,
-             conv_dtype,
-             ssm_dtype,
-             mamba_layers,
+     def mamba2_cache_params(self):
+         shape = Mamba2StateShape.create(
+             tp_world_size=get_tensor_model_parallel_world_size(),
+             intermediate_size=self.mamba_intermediate,
+             n_groups=self.mamba_n_groups,
+             num_heads=self.mamba_n_heads,
+             head_dim=self.mamba_d_head,
+             state_size=self.mamba_d_state,
+             conv_kernel=self.mamba_d_conv,
          )
+         return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
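Note (annotation, not part of the diff): the tuple-returning hybrid_gdn_params property and the standalone mamba_cache_per_req property are folded into a single structured mamba2_cache_params property. A minimal sketch of the old and new access patterns, with config standing in for a loaded FalconH1Config instance; the temporal state axes also change from (state, head_dim, heads) to (heads, head_dim, state):

# Old (0.5.3): unpack a 5-tuple, query the per-request byte count separately.
# conv_shape, temporal_shape, conv_dtype, ssm_dtype, layers = config.hybrid_gdn_params
# bytes_per_req = config.mamba_cache_per_req

# New (0.5.3.post1): one Mamba2CacheParams object.
params = config.mamba2_cache_params
conv_shape = params.shape.conv              # replaces conv_state_shape
temporal_shape = params.shape.temporal      # replaces temporal_state_shape (heads-first now)
conv_dtype = params.dtype.conv
ssm_dtype = params.dtype.temporal
bytes_per_req = params.mamba_cache_per_req  # same per-request byte accounting as before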
sglang/srt/configs/mamba_utils.py
@@ -0,0 +1,117 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
+ 
+ import os
+ from dataclasses import dataclass, field
+ 
+ import numpy as np
+ import torch
+ 
+ from sglang.srt.distributed.utils import divide
+ 
+ 
+ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
+     """Compute the increase in group numbers to account for
+     replication in order to accompany the head shards."""
+ 
+     # in the case ngoups % tp_size == 0, this will be zero
+     if ngroups % tp_size == 0:
+         return 0
+ 
+     # for n_groups == 1, this is exactly tp_size - n_groups
+     return tp_size - ngroups
+ 
+ 
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2StateShape:
+     conv: tuple[int, int]
+     temporal: tuple[int, int, int]
+ 
+     intermediate_size: int
+     conv_dim: int
+     ssm_state_size: int
+     num_heads: int
+     head_dim: int
+     state_size: int
+     conv_kernel: int
+ 
+     @staticmethod
+     def create(
+         *,
+         tp_world_size: int,
+         intermediate_size: int,
+         n_groups: int,
+         num_heads: int,
+         head_dim: int,
+         state_size: int,
+         conv_kernel: int,
+     ) -> "Mamba2StateShape":
+         # if n_groups is not divisible by world_size, need to extend the shards
+         # to ensure all groups needed by a head is sharded along with it
+         if n_groups % tp_world_size != 0:
+             extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
+             n_groups += extra_groups
+         # heads and n_groups are TP-ed
+         conv_dim = intermediate_size + 2 * n_groups * state_size
+ 
+         # contiguous along 'dim' axis
+         conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1
+ 
+         # These are not TP-ed as they depend on A, dt_bias, D
+         # - they are typically small
+         # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+         temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
+         return Mamba2StateShape(
+             conv=conv_state_shape,
+             temporal=temporal_state_shape,
+             intermediate_size=intermediate_size,
+             conv_dim=conv_dim,
+             ssm_state_size=state_size,
+             num_heads=num_heads,
+             head_dim=head_dim,
+             state_size=state_size,
+             conv_kernel=conv_kernel,
+         )
+ 
+ 
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2StateDType:
+     conv: torch.dtype
+     temporal: torch.dtype
+ 
+ 
+ CONV_DTYPE = torch.bfloat16
+ 
+ 
+ def mamba2_state_dtype() -> Mamba2StateDType:
+     dtype_map = {
+         "float32": torch.float32,
+         "bfloat16": torch.bfloat16,
+     }
+     ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
+     return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
+ 
+ 
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2CacheParams:
+     shape: Mamba2StateShape
+     dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
+     layers: list[int]
+ 
+     @property
+     def mamba_cache_per_req(self) -> int:
+         return (
+             int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+             + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
+         ) * len(self.layers)
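Note (illustration, not part of the diff): a minimal sketch of how these new helpers compose. The dimensions are hypothetical, and SGLANG_MAMBA_SSM_DTYPE must be set before Mamba2CacheParams is built because mamba2_state_dtype() reads that environment variable without a fallback:

import os

os.environ.setdefault("SGLANG_MAMBA_SSM_DTYPE", "bfloat16")

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape

shape = Mamba2StateShape.create(
    tp_world_size=1,
    intermediate_size=128 * 64,  # hypothetical: num_heads * head_dim
    n_groups=8,
    num_heads=128,
    head_dim=64,
    state_size=128,
    conv_kernel=4,
)
# layers would normally come from the config's mamba/linear layer ids.
params = Mamba2CacheParams(shape=shape, layers=list(range(27)))

print(shape.conv)      # (conv_dim / tp, conv_kernel - 1)
print(shape.temporal)  # (num_heads / tp, head_dim, state_size)
print(params.mamba_cache_per_req)  # bytes of Mamba state reserved per request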
sglang/srt/configs/model_config.py
@@ -17,7 +17,7 @@ import logging
  import math
  import os
  from enum import Enum, IntEnum, auto
- from typing import List, Optional, Set, Union
+ from typing import Any, Dict, List, Optional, Set, Union

  import torch
  from transformers import PretrainedConfig
@@ -85,17 +85,21 @@ class ModelConfig:
          enable_multimodal: Optional[bool] = None,
          dtype: str = "auto",
          quantization: Optional[str] = None,
+         modelopt_quant: Optional[Union[str, Dict]] = None,
          override_config_file: Optional[str] = None,
          is_draft_model: bool = False,
          hybrid_kvcache_ratio: Optional[float] = None,
          model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+         sampling_defaults: str = "openai",
      ) -> None:
          # Parse args
          self.model_path = model_path
          self.revision = revision
          self.quantization = quantization
+         self.modelopt_quant = modelopt_quant
          self.is_draft_model = is_draft_model
          self.model_impl = model_impl
+         self.sampling_defaults = sampling_defaults

          # Get hf config
          self._maybe_pull_model_tokenizer_from_remote()
@@ -209,8 +213,10 @@
              enable_multimodal=server_args.enable_multimodal,
              dtype=server_args.dtype,
              quantization=server_args.quantization,
+             modelopt_quant=server_args.modelopt_quant,
              hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
              model_impl=server_args.model_impl,
+             sampling_defaults=server_args.sampling_defaults,
              **kwargs,
          )

@@ -477,54 +483,52 @@ class ModelConfig:
              # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
              # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
              is_local = os.path.exists(self.model_path)
-             modelopt_quant_config = {"quant_method": "modelopt"}
              if not is_local:
                  import huggingface_hub

                  try:
-                     from huggingface_hub import HfApi
+                     from huggingface_hub import HfApi, hf_hub_download

                      hf_api = HfApi()
-
-                     def check_hf_quant_config():
-                         return hf_api.file_exists(
-                             self.model_path, "hf_quant_config.json"
+                     if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                         # Download and parse the quantization config for remote models
+                         quant_config_file = hf_hub_download(
+                             repo_id=self.model_path,
+                             filename="hf_quant_config.json",
+                             revision=self.revision,
                          )
-
-                     # Retry HF API call up to 3 times
-                     file_exists = retry(
-                         check_hf_quant_config,
-                         max_retry=2,
-                         initial_delay=1.0,
-                         max_delay=5.0,
-                     )
-
-                     if file_exists:
-                         quant_cfg = modelopt_quant_config
-
+                         with open(quant_config_file) as f:
+                             quant_config_dict = json.load(f)
+                         quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
                  except huggingface_hub.errors.OfflineModeIsEnabled:
                      logger.warning(
                          "Offline mode is enabled, skipping hf_quant_config.json check"
                      )
-                 except Exception as e:
-                     logger.warning(
-                         f"Failed to check hf_quant_config.json: {self.model_path} {e}"
-                     )
-
+                     pass
              elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                  quant_config_file = os.path.join(
                      self.model_path, "hf_quant_config.json"
                  )
                  with open(quant_config_file) as f:
                      quant_config_dict = json.load(f)
-                 json_quant_configs = quant_config_dict["quantization"]
-                 quant_algo = json_quant_configs.get("quant_algo", None)
-                 if quant_algo == "MIXED_PRECISION":
-                     quant_cfg = {"quant_method": "w4afp8"}
-                 else:
-                     quant_cfg = modelopt_quant_config
+                 quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
          return quant_cfg

+     def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+         """Parse ModelOpt quantization config and return the appropriate quant_method."""
+         json_quant_configs = quant_config_dict["quantization"]
+         quant_algo = json_quant_configs.get("quant_algo", None)
+ 
+         if quant_algo == "MIXED_PRECISION":
+             return {"quant_method": "w4afp8"}
+         elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
+             return {"quant_method": "modelopt_fp4"}
+         elif quant_algo and "FP8" in quant_algo:
+             return {"quant_method": "modelopt_fp8"}
+         else:
+             # Default to FP8 for backward compatibility
+             return {"quant_method": "modelopt_fp8"}
+ 
      # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
      def _verify_quantization(self) -> None:
          supported_quantization = [*QUANTIZATION_METHODS]
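Note (illustration, not part of the diff): the new _parse_modelopt_quant_config helper maps the "quant_algo" field of hf_quant_config.json to a quant_method string instead of the old catch-all "modelopt". A standalone sketch of that mapping with an assumed payload:

# Assumed minimal hf_quant_config.json payload; real files contain more fields.
quant_config_dict = {"quantization": {"quant_algo": "NVFP4"}}

quant_algo = quant_config_dict["quantization"].get("quant_algo")
if quant_algo == "MIXED_PRECISION":
    quant_method = "w4afp8"
elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
    quant_method = "modelopt_fp4"
elif quant_algo and "FP8" in quant_algo:
    quant_method = "modelopt_fp8"
else:
    quant_method = "modelopt_fp8"  # fallback kept for backward compatibility

print(quant_method)  # -> modelopt_fp4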
@@ -543,7 +547,8 @@ class ModelConfig:
          optimized_quantization_methods = [
              "fp8",
              "marlin",
-             "modelopt",
+             "modelopt_fp8",
+             "modelopt_fp4",
              "gptq_marlin_24",
              "gptq_marlin",
              "awq_marlin",
@@ -657,6 +662,38 @@ class ModelConfig:
          eos_ids = eos_ids | generation_eos_ids
          return eos_ids

+     def get_default_sampling_params(self) -> dict[str, Any]:
+         """
+         Get default sampling parameters from the model's generation config.
+ 
+         This method returns non-default sampling parameters from the model's
+         generation_config.json when sampling_defaults is set to "model".
+ 
+         Returns:
+             A dictionary containing the non-default sampling parameters.
+         """
+         if self.sampling_defaults != "model":
+             return {}
+ 
+         if self.hf_generation_config is None:
+             return {}
+ 
+         config = self.hf_generation_config.to_dict()
+ 
+         available_params = [
+             "repetition_penalty",
+             "temperature",
+             "top_k",
+             "top_p",
+             "min_p",
+         ]
+ 
+         default_sampling_params = {
+             p: config.get(p) for p in available_params if config.get(p) is not None
+         }
+ 
+         return default_sampling_params
+ 
      def _maybe_pull_model_tokenizer_from_remote(self) -> None:
          """
          Pull the model config files to a temporary
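Note (illustration, not part of the diff): with sampling_defaults="model", get_default_sampling_params keeps only the listed sampling keys that the model's generation_config.json actually sets. A sketch with an assumed generation config:

# Assumed generation_config.json contents; not taken from any real model.
generation_config = {"temperature": 0.6, "top_p": 0.95, "do_sample": True}

available_params = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]
defaults = {
    p: generation_config.get(p)
    for p in available_params
    if generation_config.get(p) is not None
}
print(defaults)  # {'temperature': 0.6, 'top_p': 0.95}; 'do_sample' is not a sampling default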
sglang/srt/configs/nemotron_h.py
@@ -0,0 +1,286 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
+ 
+ """NemotronH model configuration"""
+ 
+ import regex as re
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ 
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+ 
+ logger = logging.get_logger(__name__)
+ 
+ MAMBA = "M"
+ ATTENTION = "*"
+ MLP = "-"
+ 
+ 
+ class NemotronHConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a
+     [`NemotronHModel`]. It is used to instantiate a NemotronH model according
+     to the specified arguments, defining the model architecture. Instantiating
+     a configuration with the defaults will yield a similar configuration to
+     that of the NemotronH-v0.1 model.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 131072):
+             Vocabulary size of the NemotronH model. Defines the number of
+             different tokens that can be represented by the `inputs_ids`
+             passed when calling [`NemotronHModel`]
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be
+             tied. Note that this is only relevant if the model has an output
+             word embedding layer.
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 21504):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 52):
+             Number of hidden layers in the Transformer encoder.
+         hybrid_override_pattern (`str`, *optional*, defaults to
+             `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
+             The pattern of the hybrid model. The pattern is a string of
+             characters where each character represents
+             M: Mamba2, *: Attention, -: MLP
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the
+             Transformer encoder.
+         attention_head_dim (`int`, *optional*, defaults to 128):
+             Dimension of each attention head.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             This is the number of key_value heads that should be used to
+             implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use
+             Multi Head Attention (MHA), if `num_key_value_heads=1` the model
+             will use Multi Query Attention (MQA) otherwise GQA is used.
+         mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
+             The non-linear activation function in the MLP layers.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in attention layers.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in MLP layers.
+         use_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the model.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for
+             initializing all weight matrices.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+             The epsilon used by the layer normalization layers.
+         residual_in_fp32 (`bool`, *optional*, defaults to `False`):
+             Whether or not residuals should be in `float32`. If set to `False`
+             residuals will keep the same `dtype` as the rest of the model.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values
+             attentions (not used by all models). Only relevant if
+             `config.is_decoder=True`.
+         num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+             Number of prompt logits to calculate during generation. If `None`,
+             all logits will be calculated. If an integer value, only last
+             `num_logits_to_keep` logits will be calculated.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the "end-of-sequence" token.
+         sliding_window (`int`, *optional*, defaults to None):
+             Sliding window attention window size.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             The maximum sequence length that this model might ever be used
+             with.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         hidden_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the hidden states.
+         use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+             Flag indicating whether or not to use the fast mamba kernels.
+             These are available only if `mamba-ssm` and `causal-conv1d`
+             are installed, and the mamba modules are running on a CUDA device.
+         ssm_state_size (`int`, *optional*, defaults to 128):
+             The dimension of the mamba state space latents.
+         mamba_num_heads (`int`, *optional*, defaults to 128):
+             Number of heads in Mamba layers.
+         mamba_n_groups (`int`, *optional*, defaults to 8):
+             Number of groups in Mamba layers.
+         mamba_head_dim (`int`, *optional*, defaults to 64):
+             Dimension of each Mamba head.
+         mamba_d_conv (`int`, *optional*, defaults to 4):
+             The size of the mamba convolution kernel.
+         mamba_expand (`int`, *optional*, defaults to 2):
+             Expanding factor used to determine the mamba intermediate size.
+         mamba_hidden_act (`str`, *optional*, defaults to "silu"):
+             The non-linear activation function in the Mamba layers.
+         mamba_dt_min (`float`, *optional*, defaults to 0.001):
+             Minimum value for the time step in Mamba.
+         mamba_dt_max (`float`, *optional*, defaults to 0.1):
+             Maximum value for the time step in Mamba.
+         mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
+             Limits for the time step in Mamba.
+         mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
+             Floor value for time step initialization in Mamba.
+         mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+             Whether to use bias in the convolution layer of the mamba mixer
+             block.
+         mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the input and output projections of the
+             mamba mixer block.
+         mamba_chunk_size (`int`, *optional*, defaults to 256):
+             Size of chunks for Mamba processing.
+         rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the pre-normalization residual connections.
+     """
+ 
+     model_type = "nemotron_h"
+     keys_to_ignore_at_inference = ["past_key_values"]
+ 
+     def __init__(
+         self,
+         vocab_size=131072,
+         tie_word_embeddings=False,
+         hidden_size=4096,
+         intermediate_size=21504,
+         num_hidden_layers=52,
+         hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
+         num_attention_heads=32,
+         head_dim=128,
+         num_key_value_heads=8,  # nemo: num_query_groups
+         mlp_hidden_act="relu2",
+         attention_bias=False,
+         mlp_bias=False,
+         use_bias=False,
+         initializer_range=0.02,  # nemo: init_method_std
+         layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
+         residual_in_fp32=False,  # Megatron Core default value
+         use_cache=True,
+         num_logits_to_keep=1,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         sliding_window=None,
+         max_position_embeddings=4096,
+         attention_dropout=0.0,
+         hidden_dropout=0.0,  # * ADDED
+         use_mamba_kernels=True,
+         ssm_state_size=128,  # mamba_state_size
+         mamba_num_heads=128,
+         mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
+         mamba_head_dim=64,
+         mamba_d_conv=4,
+         mamba_expand=2,
+         mamba_hidden_act="silu",
+         mamba_dt_min=0.001,
+         mamba_dt_max=0.1,
+         mamba_dt_limit=(0.0, float("inf")),
+         mamba_dt_init_floor=1e-4,
+         mamba_conv_bias=True,
+         mamba_proj_bias=False,
+         mamba_chunk_size=256,
+         rescale_prenorm_residual=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.tie_word_embeddings = tie_word_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hybrid_override_pattern = hybrid_override_pattern
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim
+         self.sliding_window = sliding_window
+         self.max_position_embeddings = max_position_embeddings
+         self.attention_dropout = attention_dropout
+         self.hidden_dropout = hidden_dropout
+ 
+         # Validate hybrid_override_pattern
+         # M: Mamba2, *: Attention, -: MLP
+         assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
+             "hybrid_override_pattern must have same length as " "num_hidden_layers"
+         )
+         assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
+             "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
+         )
+ 
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+ 
+         self.num_key_value_heads = num_key_value_heads
+         self.mlp_hidden_act = mlp_hidden_act
+         self.attention_bias = attention_bias
+         self.mlp_bias = mlp_bias
+         self.use_bias = use_bias
+         self.initializer_range = initializer_range
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.residual_in_fp32 = residual_in_fp32
+ 
+         self.use_cache = use_cache
+         self.num_logits_to_keep = num_logits_to_keep
+ 
+         self.use_mamba_kernels = use_mamba_kernels
+         self.mamba_n_groups = mamba_n_groups
+         self.mamba_head_dim = mamba_head_dim
+         self.ssm_state_size = ssm_state_size
+         self.mamba_num_heads = mamba_num_heads
+         self.conv_kernel = mamba_d_conv
+         self.expand = mamba_expand
+         self.mamba_hidden_act = mamba_hidden_act
+         self.time_step_min = mamba_dt_min
+         self.time_step_max = mamba_dt_max
+         self.time_step_limit = mamba_dt_limit
+         self.time_step_floor = mamba_dt_init_floor
+         self.use_conv_bias = mamba_conv_bias
+         self.mamba_proj_bias = mamba_proj_bias
+         self.mamba_chunk_size = mamba_chunk_size
+         self.rescale_prenorm_residual = rescale_prenorm_residual
+ 
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+ 
+     @property
+     def mamba_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == MAMBA
+         ]
+ 
+     @property
+     def full_attention_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == ATTENTION
+         ]
+ 
+     @property
+     def mamba2_cache_params(self) -> Mamba2CacheParams:
+         shape = Mamba2StateShape.create(
+             tp_world_size=get_attention_tp_size(),
+             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
+             n_groups=self.n_groups,
+             num_heads=self.mamba_num_heads,
+             head_dim=self.mamba_head_dim,
+             state_size=self.ssm_state_size,
+             conv_kernel=self.conv_kernel,
+         )
+ 
+         return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
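Note (illustration, not part of the diff): how NemotronHConfig's hybrid_override_pattern string is interpreted by the layer-id properties above. The 8-layer pattern below is hypothetical, not a real NemotronH configuration:

MAMBA, ATTENTION, MLP = "M", "*", "-"
pattern = "M-M*M-M*"

mamba_layer_ids = [i for i, c in enumerate(pattern) if c == MAMBA]
full_attention_layer_ids = [i for i, c in enumerate(pattern) if c == ATTENTION]
mlp_layer_ids = [i for i, c in enumerate(pattern) if c == MLP]

print(mamba_layer_ids)           # [0, 2, 4, 6]
print(full_attention_layer_ids)  # [3, 7]
print(mlp_layer_ids)             # [1, 5]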