sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -514,6 +514,8 @@ class BenchmarkMetrics:
     p99_itl_ms: float
     mean_e2e_latency_ms: float
     median_e2e_latency_ms: float
+    std_e2e_latency_ms: float
+    p99_e2e_latency_ms: float


 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -563,7 +565,7 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not os.path.isfile(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -873,6 +875,8 @@ def calculate_metrics(
         p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
         mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
+        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
+        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
     )

     return metrics, output_lens
@@ -1064,8 +1068,21 @@ async def benchmark(
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
+            "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
+            "mean_ttft_ms": metrics.mean_ttft_ms,
             "median_ttft_ms": metrics.median_ttft_ms,
+            "std_ttft_ms": metrics.std_ttft_ms,
+            "p99_ttft_ms": metrics.p99_ttft_ms,
+            "mean_tpot_ms": metrics.mean_tpot_ms,
+            "median_tpot_ms": metrics.median_tpot_ms,
+            "std_tpot_ms": metrics.std_tpot_ms,
+            "p99_tpot_ms": metrics.p99_tpot_ms,
+            "mean_itl_ms": metrics.mean_itl_ms,
             "median_itl_ms": metrics.median_itl_ms,
+            "std_itl_ms": metrics.std_itl_ms,
+            "p99_itl_ms": metrics.p99_itl_ms,
+            "input_throughput": metrics.input_throughput,
             "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
sglang/lang/interpreter.py CHANGED
@@ -96,6 +96,7 @@ def run_program_batch(
     default_sampling_para,
     num_threads,
     progress_bar,
+    generator_style=False,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -109,6 +110,17 @@ def run_program_batch(
         num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))

+    if generator_style:
+        return _run_program_batch_generator(
+            program,
+            backend,
+            batch_arguments,
+            default_sampling_para,
+            num_threads,
+            progress_bar,
+        )
+
+    # Original code path when generator_style=False
     if num_threads == 1:
         rets = []
         if progress_bar:
@@ -168,6 +180,64 @@ def run_program_batch(
     return rets


+def _run_program_batch_generator(
+    program,
+    backend,
+    batch_arguments,
+    default_sampling_para,
+    num_threads,
+    progress_bar,
+):
+    """Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor."""
+    if num_threads == 1:
+        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
+        for arguments in iterator:
+            yield run_program(
+                program,
+                backend,
+                (),
+                arguments,
+                default_sampling_para,
+                False,
+                True,
+            )
+    else:
+        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
+
+        # Process in chunks to avoid overwhelming ThreadPoolExecutor
+        # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks
+        # so we will never reach "yield" until all tasks are done
+        chunk_size = 200
+
+        with ThreadPoolExecutor(num_threads) as executor:
+            for chunk_start in range(0, len(batch_arguments), chunk_size):
+                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
+                chunk_futures = []
+
+                # Submit chunk of tasks
+                for i in range(chunk_start, chunk_end):
+                    future = executor.submit(
+                        run_program,
+                        program,
+                        backend,
+                        (),
+                        batch_arguments[i],
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                    if pbar:
+                        future.add_done_callback(lambda _: pbar.update())
+                    chunk_futures.append(future)
+
+                # Yield results from this chunk as they complete
+                for future in chunk_futures:
+                    yield future.result()
+
+        if pbar:
+            pbar.close()
+
+
 def cache_program(program, backend):
     from sglang.lang.tracer import extract_prefix_by_tracing

@@ -277,7 +347,7 @@ class StreamExecutor:
         size: int = 1,
         position_ids_offset: Optional[List[int]] = None,
     ):
-        if size > 1:
+        if size > 1 and str(self.text_):
             self.submit(SglCommitLazy())

         self.sync()
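
The chunked-submission pattern above, distilled into a standalone sketch (work and items are placeholder names, not from the diff). Bounding the number of outstanding futures lets the generator start yielding completed results before the whole batch has been enqueued:

from concurrent.futures import ThreadPoolExecutor

def run_chunked(work, items, num_threads=8, chunk_size=200):
    # Submit at most chunk_size tasks at a time, then drain that chunk in order.
    with ThreadPoolExecutor(num_threads) as executor:
        for start in range(0, len(items), chunk_size):
            futures = [executor.submit(work, x) for x in items[start : start + chunk_size]]
            for future in futures:
                yield future.result()

# Example: list(run_chunked(lambda x: x * x, range(5))) == [0, 1, 4, 9, 16]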
sglang/lang/ir.py CHANGED
@@ -227,6 +227,7 @@ class SglFunction:
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
+        generator_style: bool = False,
     ):
         from sglang.lang.interpreter import run_program_batch

@@ -277,6 +278,7 @@ class SglFunction:
             default_sampling_para,
             num_threads,
             progress_bar,
+            generator_style=generator_style,
         )

     def trace(self, *, backend=None, **kwargs):
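
A hedged usage sketch for the new flag (the program, backend setup, and inputs here are illustrative, not from the diff): with generator_style=True, run_batch returns a generator that yields each program state as it completes, instead of a fully materialized list.

import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer"))

# Assumes a backend was configured beforehand, e.g. sgl.set_default_backend(...)
questions = ["What is SGLang?", "What does TTFT measure?"]  # hypothetical inputs
states = qa.run_batch(
    [{"question": q} for q in questions],
    generator_style=True,  # yield states as they finish
)
for state in states:
    print(state["answer"])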
sglang/srt/configs/__init__.py CHANGED
@@ -1,3 +1,5 @@
+from sglang.srt.configs.chatglm import ChatGLMConfig
+from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig

@@ -5,4 +7,6 @@ __all__ = [
     "ExaoneConfig",
     "Qwen2VLConfig",
     "Qwen2VLVisionConfig",
+    "ChatGLMConfig",
+    "DbrxConfig",
 ]
sglang/srt/configs/chatglm.py ADDED
@@ -0,0 +1,78 @@
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
+
+# ChatGLM2 and ChatGLM3 share the same config.
+# ChatGLM4 is officially supported by Huggingface
+# transformers >= 4.46.0 is required
+# https://huggingface.co/docs/transformers/en/model_doc/glm
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+    model_type = "chatglm"
+    attribute_map = {
+        "num_hidden_layers": "num_layers",
+        "n_head_kv": "multi_query_group_num",
+    }
+
+    def __init__(
+        self,
+        num_layers=28,
+        padded_vocab_size=65024,
+        hidden_size=4096,
+        ffn_hidden_size=13696,
+        kv_channels=128,
+        num_attention_heads=32,
+        seq_length=2048,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        layernorm_epsilon=1e-5,
+        rmsnorm=True,
+        apply_residual_connection_post_layernorm=False,
+        post_layer_norm=True,
+        add_bias_linear=False,
+        add_qkv_bias=False,
+        interleaved_qkv=False,
+        bias_dropout_fusion=True,
+        multi_query_attention=False,
+        multi_query_group_num=1,
+        apply_query_key_layer_scaling=True,
+        attention_softmax_in_fp32=True,
+        fp32_residual_connection=False,
+        quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
+        **kwargs
+    ):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        # It is to be compatible with long lora.
+        self.max_position_embeddings = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = (
+            apply_residual_connection_post_layernorm
+        )
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        self.interleaved_qkv = interleaved_qkv
+        super().__init__(**kwargs)
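
A quick illustration (not part of the diff) of how attribute_map lets generic model code read Hugging Face-style names from this config; the aliases resolve at attribute-access time:

from sglang.srt.configs import ChatGLMConfig

cfg = ChatGLMConfig(num_layers=40, multi_query_group_num=2)
assert cfg.num_hidden_layers == 40  # alias for cfg.num_layers
assert cfg.n_head_kv == 2           # alias for cfg.multi_query_group_num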
sglang/srt/configs/dbrx.py ADDED
@@ -0,0 +1,279 @@
+# Adapted from
+# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
+# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
+"""Dbrx configuration."""
+
+from typing import Any, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore
+
+
+class DbrxAttentionConfig(PretrainedConfig):
+    """Configuration class for Dbrx Attention.
+
+    [`DbrxAttention`] class. It is used to instantiate attention layers
+    according to the specified arguments, defining the layers architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        attn_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the attention layers.
+        clip_qkv (`float`, *optional*, defaults to None):
+            If not `None`, clip the queries, keys, and values in the attention layer to this value.
+        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
+        rope_theta (float): The base frequency for rope.
+    """
+
+    def __init__(
+        self,
+        attn_pdrop: float = 0,
+        clip_qkv: Optional[float] = None,
+        kv_n_heads: int = 1,
+        rope_theta: float = 10000.0,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.attn_pdrop = attn_pdrop
+        self.clip_qkv = clip_qkv
+        self.kv_n_heads = kv_n_heads
+        self.rope_theta = rope_theta
+
+        for k in ["model_type"]:
+            if k in kwargs:
+                kwargs.pop(k)
+        if len(kwargs) != 0:
+            raise ValueError(f"Found unknown {kwargs=}")
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: str, **kwargs: Any
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "dbrx":
+            config_dict = config_dict["attn_config"]
+
+        if (
+            "model_type" in config_dict
+            and hasattr(cls, "model_type")
+            and config_dict["model_type"] != cls.model_type
+        ):
+            logger.warning(
+                "You are using a model of type %s to instantiate a model of "
+                "type %s. This is not supported for all configurations of "
+                "models and can yield errors.",
+                config_dict["model_type"],
+                cls.model_type,
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class DbrxFFNConfig(PretrainedConfig):
+    """Configuration class for Dbrx FFN.
+
+    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
+    the specified arguments, defining the layers architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
+            The dict should have a key 'name' with the value being the name of
+            the activation function along with any additional keyword arguments.
+        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
+        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
+        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
+        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
+        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
+        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
+        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
+            This should only be used for benchmarking purposes.
+    """
+
+    def __init__(
+        self,
+        ffn_act_fn: Optional[dict] = None,
+        ffn_hidden_size: int = 3584,
+        moe_num_experts: int = 4,
+        moe_top_k: int = 1,
+        moe_jitter_eps: Optional[float] = None,
+        moe_loss_weight: float = 0.01,
+        moe_normalize_expert_weights: Optional[float] = 1,
+        uniform_expert_assignment: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        if ffn_act_fn is None:
+            ffn_act_fn = {"name": "silu"}
+        self.ffn_act_fn = ffn_act_fn
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_jitter_eps = moe_jitter_eps
+        self.moe_loss_weight = moe_loss_weight
+        self.moe_normalize_expert_weights = moe_normalize_expert_weights
+        self.uniform_expert_assignment = uniform_expert_assignment
+
+        for k in ["model_type"]:
+            if k in kwargs:
+                kwargs.pop(k)
+        if len(kwargs) != 0:
+            raise ValueError(f"Found unknown {kwargs=}")
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: str, **kwargs: Any
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "dbrx":
+            config_dict = config_dict["ffn_config"]
+
+        if (
+            "model_type" in config_dict
+            and hasattr(cls, "model_type")
+            and config_dict["model_type"] != cls.model_type
+        ):
+            logger.warning(
+                "You are using a model of type %s to instantiate a model of "
+                "type %s. This is not supported for all "
+                "configurations of models and can yield errors.",
+                config_dict["model_type"],
+                cls.model_type,
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class DbrxConfig(PretrainedConfig):
+    """Configuration class for Dbrx.
+
+    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
+    specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        d_model (`int`, *optional*, defaults to 6144):
+            Dimensionality of the embeddings and hidden states.
+        n_heads (`int`, *optional*, defaults to 48):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        n_layers (`int`, *optional*, defaults to 40):
+            Number of hidden layers in the Transformer encoder.
+        max_seq_len (`int`, *optional*, defaults to 32768):
+            The maximum sequence length of the model.
+        vocab_size (`int`, *optional*, defaults to 100352):
+            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`DbrxModel`].
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability applied to the attention output before combining with residual.
+        emb_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the embedding layer.
+        attn_config (`dict`, *optional*):
+            A dictionary used to configure the model's attention module.
+        ffn_config (`dict`, *optional*):
+            A dictionary used to configure the model's FFN module.
+        use_cache (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss. See [here]() for more details
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+
+
+    Example:
+    ```python
+    >>> from transformers import DbrxConfig, DbrxModel
+
+    >>> # Initializing a Dbrx configuration
+    >>> configuration = DbrxConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = DbrxModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "dbrx"
+    attribute_map = {
+        "num_attention_heads": "n_heads",
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layers",
+        "max_position_embeddings": "max_seq_len",
+    }
+
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        max_seq_len: int = 2048,
+        vocab_size: int = 32000,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        attn_config: Optional[DbrxAttentionConfig] = None,
+        ffn_config: Optional[DbrxFFNConfig] = None,
+        use_cache: bool = True,
+        initializer_range: float = 0.02,
+        output_router_logits: bool = False,
+        router_aux_loss_coef: float = 0.05,
+        **kwargs: Any,
+    ):
+        if attn_config is None:
+            self.attn_config = DbrxAttentionConfig()
+        elif isinstance(attn_config, dict):
+            self.attn_config = DbrxAttentionConfig(**attn_config)
+        else:
+            self.attn_config = attn_config
+
+        if ffn_config is None:
+            self.ffn_config = DbrxFFNConfig()
+        elif isinstance(ffn_config, dict):
+            self.ffn_config = DbrxFFNConfig(**ffn_config)
+        else:
+            self.ffn_config = ffn_config
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+
+        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+        if tie_word_embeddings:
+            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
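
A short sketch (not part of the diff) showing that the nested attention/FFN configs accept plain dicts, the shape they take in a checkpoint's config.json, and that DbrxConfig exposes HF-style aliases via its attribute_map:

from sglang.srt.configs import DbrxConfig

cfg = DbrxConfig(
    attn_config={"kv_n_heads": 8, "rope_theta": 500000.0},
    ffn_config={"moe_num_experts": 16, "moe_top_k": 4},
)
assert cfg.attn_config.kv_n_heads == 8
assert cfg.hidden_size == cfg.d_model  # alias via attribute_map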
sglang/srt/configs/model_config.py CHANGED
@@ -128,7 +128,7 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

-        # Veirfy quantization
+        # Verify quantization
         self._verify_quantization()

         # Cache attributes
@@ -223,7 +223,11 @@ class ModelConfig:
             "compressed_tensors",
             "compressed-tensors",
             "experts_int8",
+            "w8a8_int8",
         ]
+        compatible_quantization_methods = {
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+        }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -247,12 +251,17 @@ class ModelConfig:
             if self.quantization is None:
                 self.quantization = quant_method
             elif self.quantization != quant_method:
-                raise ValueError(
-                    "Quantization method specified in the model config "
-                    f"({quant_method}) does not match the quantization "
-                    f"method specified in the `quantization` argument "
-                    f"({self.quantization})."
-                )
+                if (
+                    self.quantization not in compatible_quantization_methods
+                    or quant_method
+                    not in compatible_quantization_methods[self.quantization]
+                ):
+                    raise ValueError(
+                        "Quantization method specified in the model config "
+                        f"({quant_method}) does not match the quantization "
+                        f"method specified in the `quantization` argument "
+                        f"({self.quantization})."
+                    )

         if self.quantization is not None:
             if self.quantization not in supported_quantization:
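
In plain terms (a sketch, not code from the diff): the checkpoint's recorded quantization method and the user-requested method may now differ, as long as the pair appears in the compatibility table; every other mismatch still raises.

compatible_quantization_methods = {
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
}

def methods_compatible(requested: str, checkpoint: str) -> bool:
    # Mirrors the new check: exact match, or an explicitly allowed pairing.
    if requested == checkpoint:
        return True
    return checkpoint in compatible_quantization_methods.get(requested, [])

assert methods_compatible("w8a8_int8", "compressed-tensors")  # now accepted
assert not methods_compatible("fp8", "awq")                   # still rejected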
sglang/srt/hf_transformers_utils.py CHANGED
@@ -30,20 +30,15 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-try:
-    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
-
-    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
-
-    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-        ChatGLMConfig.model_type: ChatGLMConfig,
-        DbrxConfig.model_type: DbrxConfig,
-        ExaoneConfig.model_type: ExaoneConfig,
-        Qwen2VLConfig.model_type: Qwen2VLConfig,
-    }
-except ImportError:
-    # We want this file to run without vllm dependency
-    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
+from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig
+
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    ChatGLMConfig.model_type: ChatGLMConfig,
+    DbrxConfig.model_type: DbrxConfig,
+    ExaoneConfig.model_type: ExaoneConfig,
+    Qwen2VLConfig.model_type: Qwen2VLConfig,
+}
+

 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
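
The loop body sits outside this hunk; by assumption (this mirrors how vLLM wires such registries into transformers, and is not shown in the diff) it registers each config class with AutoConfig, suppressing the ValueError raised on duplicate registration:

import contextlib

from transformers import AutoConfig

for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)  # raises ValueError if already registered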
@@ -66,7 +66,14 @@ class AttentionBackend(ABC):
66
66
  if forward_batch.forward_mode.is_decode():
67
67
  return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
68
68
  else:
69
- return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
69
+ return self.forward_extend(
70
+ q,
71
+ k,
72
+ v,
73
+ layer,
74
+ forward_batch,
75
+ save_kv_cache,
76
+ )
70
77
 
71
78
  def forward_decode(
72
79
  self,