sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
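The pattern above repeats across the model files in this release: forward() now returns the LogitsProcessorOutput directly, and the per-model Sampler is removed. Below is a minimal, self-contained sketch of the calling convention this implies, with a greedy stand-in for the centralized sampler; LogitsProcessorOutput here is a stripped-down illustrative stand-in, not sglang's real class.

from dataclasses import dataclass

import torch


@dataclass
class LogitsProcessorOutput:
    # Stripped-down stand-in for sglang's LogitsProcessorOutput.
    next_token_logits: torch.Tensor


def sample_greedy(logits_output: LogitsProcessorOutput) -> torch.Tensor:
    # Stand-in for the sampler that now runs outside each model's forward().
    return logits_output.next_token_logits.argmax(dim=-1)


# Before (0.3.0): sample_output, logits_output = model.forward(...)
# After (0.3.1):  logits_output = model.forward(...); sampling happens once, here:
logits_output = LogitsProcessorOutput(next_token_logits=torch.randn(2, 32000))
next_token_ids = sample_greedy(logits_output)
print(next_token_ids.shape)  # torch.Size([2])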
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama.py CHANGED
@@ -41,7 +41,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -299,10 +300,10 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         self.param_dict = dict(self.named_parameters())

@@ -315,11 +316,54 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def get_hidden_dim(self, module_name):
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size // (
+                self.config.num_attention_heads // self.config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
+    def get_module_name_from_weight_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -361,6 +405,8 @@ class LlamaForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
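The get_module_name_from_weight_name helper added above maps a checkpoint weight name to the fused parameter it loads into, plus the number of shards fused there (these helpers land alongside the new sglang/srt/lora/ modules in the file list, so they are presumably consumed by the LoRA manager). Here is a standalone reproduction of that mapping logic; the example weight names are illustrative, not taken from the diff.

# Standalone reproduction of the mapping logic added to LlamaForCausalLM above.
STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id, num_shard)
    ("qkv_proj", "q_proj", "q", 3),
    ("qkv_proj", "k_proj", "k", 3),
    ("qkv_proj", "v_proj", "v", 3),
    ("gate_up_proj", "gate_proj", 0, 2),
    ("gate_up_proj", "up_proj", 1, 2),
]


def module_name_from_weight_name(name: str):
    for param_name, weight_name, _shard_id, num_shard in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            # k_proj is one of the 3 shards fused into qkv_proj, etc.
            return name.replace(weight_name, param_name)[: -len(".weight")], num_shard
    # Unfused weights map to themselves as a single shard.
    return name[: -len(".weight")], 1


print(module_name_from_weight_name("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj', 3)
print(module_name_from_weight_name("model.layers.0.mlp.down_proj.weight"))
# -> ('model.layers.0.mlp.down_proj', 1)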
sglang/srt/models/llama_classification.py CHANGED
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel

@@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )

-        # A dummy to make this work
-        sample_output = SampleOutput(
-            success=torch.full(
-                size=(scores.shape[0],),
-                fill_value=True,
-                dtype=torch.bool,
-            ),
-            probs=torch.full(
-                size=(scores.shape[0], 1),
-                fill_value=1.0,
-                dtype=torch.float16,
-            ),
-            batch_next_token_ids=torch.full(
-                size=(scores.shape[0],),
-                fill_value=0,
-                dtype=torch.long,
-            ),
-        )
-        return sample_output, logits_output
+        return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = self.param_dict
sglang/srt/models/llava.py CHANGED
@@ -136,8 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
+            # Got List[List[str]] extend it to List[str]
+            # The length of the List should be equal to batch size
+            modalities_list = []
+            for modalities in input_metadata.modalities:
+                if modalities is not None:
+                    modalities_list.extend(modalities)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -179,11 +185,14 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_image_features = []
                 height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if len(image_sizes[image_idx]) == 1:
+                    if modalities_list[image_idx] == "image":
                         image_aspect_ratio = (
                             self.config.image_aspect_ratio
                         )  # single image
-                    else:
+                    elif (
+                        modalities_list[image_idx] == "multi-images"
+                        or modalities_list[image_idx] == "video"
+                    ):
                         image_aspect_ratio = "pad"  # multi image
                         # image_aspect_ratio = (
                         #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
@@ -191,6 +200,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     if (
                         image_feature.shape[0] > 1
                         and "anyres" in image_aspect_ratio
+                        and modalities_list[image_idx] == "image"
                    ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]
@@ -290,7 +300,7 @@ class LlavaBaseForCausalLM(nn.Module):
                         )
                         image_feature = image_feature.unsqueeze(0)
                     else:
-                        if image_feature.shape[0] > 16:  # video
+                        if modalities_list[image_idx] == "video":  # video
                             # 2x2 pooling
                             num_of_frames = image_feature.shape[0]
                             image_feature = image_feature.view(
@@ -312,6 +322,21 @@ class LlavaBaseForCausalLM(nn.Module):
                                 .transpose(1, 2)
                                 .contiguous()
                             )  # N, C, H*W
+                            if "unpad" in self.mm_patch_merge_type:
+                                image_feature = torch.cat(
+                                    (
+                                        image_feature,
+                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
+                                        self.language_model.model.image_newline[
+                                            None, None
+                                        ].expand(
+                                            image_feature.shape[0],
+                                            1,
+                                            image_feature.shape[-1],
+                                        ),
+                                    ),
+                                    dim=1,
+                                )

                     new_image_features.append(image_feature)
                 image_features = new_image_features
@@ -350,7 +375,7 @@ class LlavaBaseForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
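The video branch above replaces a frame-count heuristic (image_feature.shape[0] > 16) with an explicit modality check before its 2x2 pooling step, whose body the diff viewer truncates. As a rough illustration only (the real implementation may pool differently), here is what 2x2 average pooling over a square per-frame patch grid does to the token count; all dimensions are made up.

import torch
import torch.nn.functional as F

# Hedged sketch of 2x2 pooling over per-frame patch features, assuming a
# square patch grid; sizes are illustrative, not taken from the diff.
num_frames, height, width, dim = 8, 24, 24, 1024
frames = torch.randn(num_frames, height * width, dim)  # N, H*W, C

x = frames.view(num_frames, height, width, dim).permute(0, 3, 1, 2)  # N, C, H, W
x = F.avg_pool2d(x, kernel_size=2)  # 2x2 pooling quarters the token count
pooled = x.flatten(2).transpose(1, 2)  # back to N, (H/2)*(W/2), C

print(frames.shape, "->", pooled.shape)
# torch.Size([8, 576, 1024]) -> torch.Size([8, 144, 1024])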
sglang/srt/models/llavavid.py CHANGED
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size

             # Embed text inputs
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
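Aside from the new return convention, MiniCPM's forward keeps its choice of LM-head weight: embed_tokens.weight when input and output embeddings are tied, lm_head.weight otherwise. A small illustration of weight tying, with made-up sizes and names:

import torch
import torch.nn as nn

# Illustration of tied input/output embeddings, as selected in the forward()
# above; vocab_size and hidden_size here are made up.
vocab_size, hidden_size = 1000, 64
embed_tokens = nn.Embedding(vocab_size, hidden_size)

hidden_states = torch.randn(4, hidden_size)
# With tying, logits are computed against the embedding matrix itself,
# so no separate lm_head parameter is needed.
logits = hidden_states @ embed_tokens.weight.t()
print(logits.shape)  # torch.Size([4, 1000])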