sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
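Entries 96 and 97 add new model definitions (sglang/srt/models/qwen3.py and sglang/srt/models/qwen3_moe.py); the hunk below shows the new qwen3.py. For orientation only, a minimal sketch of serving such a model through sglang's offline Engine API follows; the checkpoint path is a hypothetical placeholder, and the request is routed to the new class via the EntryClass export at the end of the file.

# Hedged sketch: assumes sglang's offline Engine API (sglang.Engine) and a
# hypothetical Qwen3 checkpoint path; not taken from the diff itself.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen3-8B")  # placeholder model id
    out = llm.generate(
        "Explain tensor parallelism in one sentence.",
        {"temperature": 0.7, "max_new_tokens": 64},
    )
    print(out["text"])
    llm.shutdown()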
--- /dev/null
+++ sglang/srt/models/qwen3.py
@@ -0,0 +1,335 @@
+ # Adapted from qwen2.py
+
+ from functools import partial
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
+ from sglang.srt.models.qwen2 import Qwen2Model
+ from sglang.srt.utils import add_prefix
+
+ Qwen3Config = None
+
+
+ class Qwen3Attention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 1000000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         head_dim: Optional[int] = None,
+         max_position_embeddings: int = 32768,
+         quant_config: Optional[QuantizationConfig] = None,
+         rms_norm_eps: float = None,
+         attention_bias: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % self.tp_size == 0
+         self.num_heads = self.total_num_heads // self.tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= self.tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+         )
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         q_by_head = q.reshape(-1, self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.reshape(-1, self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Qwen3DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3Config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 1000000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+         head_dim = getattr(config, "head_dim", None)
+         self.self_attn = Qwen3Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             head_dim=head_dim,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             rms_norm_eps=config.rms_norm_eps,
+             attention_bias=config.attention_bias,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         self.mlp = Qwen3MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class Qwen3Model(Qwen2Model):
+     def __init__(
+         self,
+         config: Qwen3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             config=config,
+             quant_config=quant_config,
+             prefix=prefix,
+             decoder_layer_type=Qwen3DecoderLayer,
+         )
+
+
+ class Qwen3ForCausalLM(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
+     def __init__(
+         self,
+         config: Qwen3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen3Model(
+             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+         )
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         get_embedding: bool = False,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head, forward_batch
+             )
+         else:
+             return self.pooler(hidden_states, forward_batch)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                 continue
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         self.model.load_kv_cache_scales(quantization_param_path)
+
+
+ EntryClass = Qwen3ForCausalLM
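
The main structural difference from Qwen2 in this file is the per-head RMSNorm applied to the query and key projections before RoPE (Qwen3Attention._apply_qk_norm above). The following self-contained sketch shows what the reshape/view pair does, using plain torch instead of sglang's RMSNorm module; the shapes and weight names are illustrative only and follow the qkv_proj split in Qwen3Attention.forward.

import torch

def per_head_rms_norm(x: torch.Tensor, weight: torch.Tensor, head_dim: int,
                      eps: float = 1e-6) -> torch.Tensor:
    # Fold (num_tokens, num_heads * head_dim) into (num_tokens * num_heads, head_dim)
    # so each head is normalized independently, then restore the original shape.
    x_by_head = x.reshape(-1, head_dim)
    variance = x_by_head.pow(2).mean(-1, keepdim=True)
    x_by_head = x_by_head * torch.rsqrt(variance + eps) * weight
    return x_by_head.view(x.shape)

# Illustrative shapes: 4 tokens, 8 query heads, 2 KV heads, head_dim 128.
q = torch.randn(4, 8 * 128)
k = torch.randn(4, 2 * 128)
q_w = torch.ones(128)  # stands in for q_norm.weight
k_w = torch.ones(128)  # stands in for k_norm.weight
q, k = per_head_rms_norm(q, q_w, 128), per_head_rms_norm(k, k_w, 128)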