sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -25,9 +25,11 @@ from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -35,17 +37,48 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
     across all ranks.
@@ -57,6 +90,7 @@ class Grok1MoE(nn.Module):
 
     def __init__(
         self,
+        config: PretrainedConfig,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -64,6 +98,7 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        reduce_results=True,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -77,13 +112,16 @@ class Grok1MoE(nn.Module):
             quant_config=None,
         )
 
+        self.router_logit_softcapping = getattr(
+            config, "router_logit_softcapping", 30.0
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -93,9 +131,12 @@ class Grok1MoE(nn.Module):
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         router_logits = 30.0 * F.tanh(router_logits / 30.0)
+
+        # need to assert self.gate.quant_method is unquantized
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)
 
@@ -103,16 +144,18 @@ class Grok1MoE(nn.Module):
 class Grok1Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        logit_cap: float = 30,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.config = config
+        self.layer_id = layer_id
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -128,7 +171,7 @@ class Grok1Attention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = 128
+        self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -142,7 +185,6 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -156,6 +198,9 @@ class Grok1Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
+
+        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -164,7 +209,6 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
-        # TODO(lianmin): load logit cap from config
 
     def forward(
         self,
@@ -188,10 +232,12 @@ class Grok1DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
@@ -201,11 +247,17 @@ class Grok1DecoderLayer(nn.Module):
             quant_config=quant_config,
         )
         self.block_sparse_moe = Grok1MoE(
+            config=config,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "intermediate_size", None),
+            ),
             quant_config=quant_config,
+            reduce_results=True,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -286,11 +338,11 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -313,6 +365,8 @@ class Grok1ForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -348,6 +402,11 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -360,7 +419,9 @@ class Grok1ForCausalLM(nn.Module):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
@@ -374,8 +435,6 @@ class Grok1ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatbility."""
sglang/srt/models/llama.py CHANGED
@@ -36,12 +36,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -296,6 +294,28 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
@@ -304,7 +324,6 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         # Llama 3.2 1B Insturct set tie_word_embeddings to True
         # Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +443,6 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100, tp_size: int = 1
     ) -> Optional[torch.Tensor]:
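
The new class attributes on LlamaForCausalLM describe how per-shard checkpoint weights relate to the fused parameters when loading bitsandbytes-quantized models. A rough sketch of how such a mapping can be consumed; the loader below is illustrative rather than sglang's actual loading code, and the checkpoint name assumes the usual HF Llama layout:

# shard_name -> (fused_weight_name, shard_index), mirroring the mapping declared above
BNB_STACKED_PARAMS = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def to_fused_name(ckpt_name: str):
    # Translate a per-shard checkpoint name into the fused parameter name
    # plus the index of the shard inside that fused parameter.
    for shard_name, (fused_name, index) in BNB_STACKED_PARAMS.items():
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, fused_name), index
    return ckpt_name, None  # not a stacked parameter

print(to_fused_name("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 1)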
sglang/srt/models/llama_classification.py CHANGED
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -33,14 +33,13 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
 
         self.classification_head = nn.Linear(
             config.hidden_size, config.classification_out_size, bias=False
         )
-        self.eos_token_id = config.eos_token_id
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
 
     @torch.no_grad()
     def forward(
@@ -49,28 +48,17 @@ class LlamaForClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        is_eos_token = input_ids == self.eos_token_id
-        hidden_states = hidden_states[is_eos_token]
-        scores = self.classification_head(hidden_states)
-
-        if scores.shape[0] != forward_batch.batch_size:
-            print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones(
-                (forward_batch.batch_size, self.config.classification_out_size)
-            ).to(input_ids.device)
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
 
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=scores,
-            next_token_logprobs=scores,
-            normalized_prompt_logprobs=scores,
-            input_token_logprobs=torch.ones_like(input_ids),
-            input_top_logprobs=None,
-            output_top_logprobs=None,
-        )
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.classification_head(last_token_hidden)
 
-        return logits_output
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
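
LlamaForClassification now scores the last token of each sequence through the shared Pooler instead of indexing EOS tokens by id. A small sketch of last-token pooling over a packed batch, with hypothetical shapes and standalone PyTorch in place of sglang's ForwardBatch machinery:

import torch

def pool_last_token(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: all tokens of all sequences concatenated along dim 0.
    # The last token of sequence i sits at cumsum(seq_lens)[i] - 1.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1
    return hidden_states[last_indices]

hidden_states = torch.randn(10, 16)            # 10 packed tokens, hidden size 16
seq_lens = torch.tensor([3, 7])                # two sequences
head = torch.nn.Linear(16, 4, bias=False)      # classification head, 4 classes
scores = head(pool_last_token(hidden_states, seq_lens))
print(scores.shape)                            # torch.Size([2, 4])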
sglang/srt/models/llama_reward.py CHANGED
@@ -21,7 +21,6 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -33,7 +32,6 @@ class LlamaForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llava.py CHANGED
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
+        image_inputs.image_pad_len = []
         for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
+            image_inputs.image_pad_len.append(new_image_feature_len)
 
         image_inputs.image_offsets = offset_list
         return input_ids
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs
 
         if forward_batch.forward_mode.is_extend():
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
                 if im and im.image_offsets:
-                    max_image_offset.append(max(im.image_offsets))
+                    max_image_offset.append(
+                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
+                    )
                 else:
                     max_image_offset.append(-1)
 
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
 
@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):
 
             # Fill in the placeholder for the image
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
            pt = 0
            for i in range(bs):
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
+                seq_len = extend_seq_lens[i]
                 prefix_len = prefix_lens_cpu[i]
 
                 # Multiple images
-                for j, image_offset in enumerate(image_inputs[i].image_offsets):
-                    if image_offset < prefix_len:
+                for image_idx, image_offset in enumerate(
+                    image_inputs[i].image_offsets
+                ):
+                    if (
+                        image_offset + image_inputs[i].image_pad_len[image_idx]
+                        <= prefix_len
+                    ):
                         continue
+                    if image_offset >= prefix_len + seq_len:
+                        break
 
-                    tmp_image_feature = image_features[pt][j]
+                    tmp_image_feature = image_features[pt][image_idx]
                     pad_len = tmp_image_feature.shape[0]
 
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    input_offset = image_offset - prefix_len
+                    left_idx = start_idx + input_offset
+                    right_idx = left_idx + pad_len
+                    assert right_idx > start_idx
+                    if input_offset < 0:
+                        left_idx = start_idx
+                        tmp_image_feature = tmp_image_feature[-input_offset:]
+                    if right_idx > start_idx + seq_len:
+                        tmp_image_feature = tmp_image_feature[
+                            : start_idx + seq_len - right_idx
+                        ]
+                        right_idx = start_idx + seq_len
                     try:
                         input_embeds[left_idx:right_idx] = tmp_image_feature
                     except RuntimeError as e:
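
The llava.py changes track each image's padded length (`image_pad_len`) so that, under chunked prefill, only the portion of an image that falls inside the current extend window is copied into `input_embeds`. A plain-Python sketch of that clipping logic with hypothetical numbers:

def clip_image_slice(image_offset, pad_len, prefix_len, seq_len, start_idx):
    # The image occupies [image_offset, image_offset + pad_len) in the full prompt;
    # the current chunk covers prompt positions [prefix_len, prefix_len + seq_len).
    input_offset = image_offset - prefix_len
    left_idx = start_idx + input_offset
    right_idx = left_idx + pad_len
    feat_lo, feat_hi = 0, pad_len              # which rows of the image feature to copy
    if input_offset < 0:                       # image started before this chunk
        feat_lo = -input_offset
        left_idx = start_idx
    if right_idx > start_idx + seq_len:        # image ends after this chunk
        feat_hi = pad_len + (start_idx + seq_len - right_idx)
        right_idx = start_idx + seq_len
    return left_idx, right_idx, feat_lo, feat_hi

# Image of 6 tokens at offset 8; this chunk extends prompt positions [10, 20).
print(clip_image_slice(image_offset=8, pad_len=6, prefix_len=10, seq_len=10, start_idx=0))
# -> (0, 4, 2, 6): copy feature rows 2..6 into embedding rows 0..4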
sglang/srt/models/mixtral.py CHANGED
@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,7 +38,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -65,6 +68,7 @@ class MixtralMoE(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.hidden_size = hidden_size
 
         # Gate always runs at half / full precision for now.
@@ -76,14 +80,13 @@ class MixtralMoE(nn.Module):
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -97,6 +100,8 @@ class MixtralMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
 
 
@@ -295,7 +300,6 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -322,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
@@ -387,7 +392,5 @@ class MixtralForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
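
In mixtral.py the MoE implementation is now chosen from the `enable_ep_moe` server flag (EPMoE for expert parallelism, FusedMoE otherwise), and the cross-rank reduction moves out of the expert layer into MixtralMoE.forward as an explicit all-reduce. A single-process sketch of what that all-reduce amounts to, with hypothetical sizes (a real run performs this across tensor-parallel ranks via tensor_model_parallel_all_reduce):

import torch

tp_size = 4
# Each tensor-parallel rank holds only a slice of every expert, so it produces a
# partial output for the same tokens; the all-reduce is an element-wise sum.
partials = [torch.randn(2, 8) for _ in range(tp_size)]   # one partial output per rank
full = torch.stack(partials).sum(dim=0)                   # what the all-reduce computes
print(full.shape)                                          # torch.Size([2, 8])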
sglang/srt/models/phi3_small.py CHANGED
@@ -17,13 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ class Phi3SmallForCausalLM(nn.Module):
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ class Phi3SmallForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
sglang/srt/models/qwen2.py CHANGED
@@ -267,6 +267,26 @@ class Qwen2Model(nn.Module):
 
 
 class Qwen2ForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: Qwen2Config,
sglang/srt/models/qwen2_moe.py CHANGED
@@ -40,12 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
@@ -352,7 +350,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@ class Qwen2MoeForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/torch_native_llama.py CHANGED
@@ -58,12 +58,10 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
@@ -392,7 +390,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         if self.config.tie_word_embeddings:
@@ -503,8 +500,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass