sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import print_warning_once
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
-
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
         self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     params_dtype=self.params_dtype,
-                                     quant_config=None)
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
+        )
         self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
         # Used for fp8.
         self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
             # If loading fp8 checkpoint, pass the weight loaders.
             # If loading an fp16 checkpoint, do not (we will quantize in
             # process_weights_after_loading()
             if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
             # ACT_SCALE (for fp8)
             if quant_config.activation_scheme == "static":
                 if not quant_config.is_checkpoint_fp8_serialized:
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+    ):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
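Aside (not part of the diff): the set_weight_attrs calls reformatted above matter because they attach a `weight_loader` attribute to each fused parameter, and the checkpoint-loading loop later looks that attribute up and calls it once per expert tensor. A minimal sketch of that handshake, with illustrative names and shapes rather than the package's real ones:

```python
# Minimal sketch of the set_weight_attrs / weight_loader handshake.
# Names and shapes here are illustrative, not taken from the package.
import torch
from torch import nn

param = nn.Parameter(torch.empty(8, 2 * 16, 32))  # (experts, 2*intermediate, hidden)

def demo_loader(param, loaded_weight, weight_name, expert_id):
    # A real loader would copy the tensor into the right slice of `param`.
    print(f"would route {weight_name} {tuple(loaded_weight.shape)} into expert {expert_id}")

# set_weight_attrs(param, {"weight_loader": demo_loader}) boils down to setattr:
setattr(param, "weight_loader", demo_loader)

# ...and a loading loop later does roughly this:
loader = getattr(param, "weight_loader", None)
if loader is not None:
    loader(param, torch.zeros(16, 32), "experts.0.w1.weight", expert_id=0)
```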
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
         if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
 
         # If checkpoint is fp16, quantize here.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
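For readers unfamiliar with the fp8 path touched above: ops.scaled_fp8_quant quantizes each expert's weight to float8_e4m3fn with a single scale per tensor. A naive pure-PyTorch stand-in that shows the idea; this is not vllm's kernel, it assumes dynamic scaling from the tensor's absolute maximum, and it requires a PyTorch build with float8 dtypes:

```python
# Naive per-tensor fp8 quantization, standing in for ops.scaled_fp8_quant.
# Illustrative only.
import torch

def naive_scaled_fp8_quant(w: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)             # max representable ~448
    scale = w.abs().max().clamp(min=1e-12) / finfo.max   # one scale per tensor
    q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

w = torch.randn(16, 32)
q, s = naive_scaled_fp8_quant(w)
w_approx = q.to(torch.float32) * s   # dequantize to check the round trip
```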
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
+                    "activation scales are None."
+                )
 
-            if (not all_close_1d(self.a13_scale)
-                    or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. ")
+                    "Using the maximum across experts for each layer. "
+                )
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                          requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_size)
 
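The reformatted fused_moe call above is the entire forward pass of the MoE block. As a reading aid, here is a rough, unoptimized PyTorch sketch of what such a fused kernel computes, as far as the arguments shown here suggest (softmax routing, top-k expert selection with renormalized weights, SiLU-gated expert MLPs); the fp8 scale arguments are ignored and nothing here is the actual Triton kernel:

```python
# Unoptimized reference for the routing + expert-MLP math behind fused_moe.
# Illustrative only; real kernels batch tokens per expert instead of looping.
import torch
import torch.nn.functional as F

def naive_moe(x, w13, w2, router_logits, top_k):
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_w, topk_ids = probs.topk(top_k, dim=-1)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize=True
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):                          # one token at a time
        for w, e in zip(topk_w[t], topk_ids[t]):
            gate_up = x[t] @ w13[e].t()                  # fused w1/w3 projection
            gate, up = gate_up.chunk(2)
            out[t] += w.item() * ((F.silu(gate) * up) @ w2[e].t())
    return out

# toy sizes: 3 tokens, hidden 8, intermediate 16, 4 experts, top-2 routing
x = torch.randn(3, 8)
w13 = torch.randn(4, 2 * 16, 8)   # (experts, 2*intermediate, hidden)
w2 = torch.randn(4, 8, 16)        # (experts, hidden, intermediate)
y = naive_moe(x, w13, w2, torch.randn(3, 4), top_k=2)
```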
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
-            quant_config=quant_config)
+            quant_config=quant_config,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
-        ] + [
-            # These are the activation scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
-        ]
+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
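The restructured expert_params_mapping above is just a list of (param_name, checkpoint_weight_name, expert_id) triples that the loading loop matches against checkpoint tensor names. A tiny illustration of the tuples the middle comprehension produces (two experts here purely for brevity):

```python
# Toy expansion of the "weights" comprehension in expert_params_mapping.
num_local_experts = 2  # made-up value for illustration
mapping = [
    ("w13_weight" if w in ["w1", "w3"] else "w2_weight", f"experts.{e}.{w}.weight", e)
    for e in range(num_local_experts)
    for w in ["w1", "w2", "w3"]
]
# mapping == [
#   ("w13_weight", "experts.0.w1.weight", 0),
#   ("w2_weight",  "experts.0.w2.weight", 0),
#   ("w13_weight", "experts.0.w3.weight", 0),
#   ("w13_weight", "experts.1.w1.weight", 1),
#   ...
# ]
```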
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  weight_name,
-                                  expert_id=expert_id)
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                     weight_loader(param, loaded_weight)
 
 
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
sglang/srt/models/qwen.py CHANGED
@@ -1,6 +1,6 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
-from typing import Any, Dict, Optional, Iterable, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/qwen2.py CHANGED
@@ -1,7 +1,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn