sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -25,9 +25,11 @@ from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -35,17 +37,48 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
     across all ranks.
@@ -57,6 +90,7 @@ class Grok1MoE(nn.Module):
 
     def __init__(
         self,
+        config: PretrainedConfig,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -64,6 +98,7 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        reduce_results=True,
    ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -77,13 +112,16 @@ class Grok1MoE(nn.Module):
             quant_config=None,
         )
 
+        self.router_logit_softcapping = getattr(
+            config, "router_logit_softcapping", 30.0
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -93,9 +131,12 @@ class Grok1MoE(nn.Module):
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         router_logits = 30.0 * F.tanh(router_logits / 30.0)
+
+        # need to assert self.gate.quant_method is unquantized
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)
 
@@ -103,16 +144,18 @@ class Grok1MoE(nn.Module):
 class Grok1Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        logit_cap: float = 30,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.config = config
+        self.layer_id = layer_id
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -128,7 +171,7 @@ class Grok1Attention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = 128
+        self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -142,7 +185,6 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -156,6 +198,9 @@ class Grok1Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
+
+        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -164,7 +209,6 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
-        # TODO(lianmin): load logit cap from config
 
     def forward(
         self,
@@ -188,10 +232,12 @@ class Grok1DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
@@ -201,11 +247,17 @@ class Grok1DecoderLayer(nn.Module):
             quant_config=quant_config,
         )
         self.block_sparse_moe = Grok1MoE(
+            config=config,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "intermediate_size", None),
+            ),
             quant_config=quant_config,
+            reduce_results=True,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -286,11 +338,11 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -313,6 +365,8 @@ class Grok1ForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -348,6 +402,11 @@ class Grok1ForCausalLM(nn.Module):
                    continue
                name = name.replace(weight_name, param_name)
 
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
+                    continue
+
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
@@ -360,7 +419,9 @@ class Grok1ForCausalLM(nn.Module):
                    break
            else:
                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
@@ -374,8 +435,6 @@ class Grok1ForCausalLM(nn.Module):
                )
                weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatbility."""
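A pattern worth noting in the grok.py changes above: constants that were previously hardcoded (head_dim = 128, the attention logit cap of 30, the router logit softcap of 30) are now read from the HuggingFace config with the old values kept as fallbacks. A minimal sketch of that fallback behavior, using a stand-in config object rather than a real Grok-1 config:

# Sketch only: `cfg` stands in for a transformers.PretrainedConfig with no Grok-1 fields set.
from types import SimpleNamespace

cfg = SimpleNamespace()  # no head_dim / softcapping attributes present

head_dim = getattr(cfg, "head_dim", 128)                                 # falls back to the old hardcoded 128
attn_logit_cap = max(getattr(cfg, "attn_logit_softcapping", 30.0), 0.0)  # old fixed logit_cap=30
router_logit_cap = getattr(cfg, "router_logit_softcapping", 30.0)

print(head_dim, attn_logit_cap, router_logit_cap)  # 128 30.0 30.0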
sglang/srt/models/llama.py
CHANGED
@@ -36,12 +36,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -296,6 +294,28 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
@@ -304,7 +324,6 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         # Llama 3.2 1B Insturct set tie_word_embeddings to True
         # Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +443,6 @@ class LlamaForCausalLM(nn.Module):
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
sglang/srt/models/llama_classification.py
CHANGED
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -33,14 +33,13 @@ class LlamaForClassification(nn.Module):
    ) -> None:
        super().__init__()
        self.config = config
-        self.torchao_config = None
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)
 
        self.classification_head = nn.Linear(
            config.hidden_size, config.classification_out_size, bias=False
        )
-        self.eos_token_id = config.eos_token_id
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
 
    @torch.no_grad()
    def forward(
@@ -49,28 +48,17 @@ class LlamaForClassification(nn.Module):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
-
-
-
-
-
-
-        if scores.shape[0] != forward_batch.batch_size:
-            print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones(
-                (forward_batch.batch_size, self.config.classification_out_size)
-            ).to(input_ids.device)
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
 
-
-
-
-            normalized_prompt_logprobs=scores,
-            input_token_logprobs=torch.ones_like(input_ids),
-            input_top_logprobs=None,
-            output_top_logprobs=None,
-        )
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.classification_head(last_token_hidden)
 
-        return
+        return EmbeddingPoolerOutput(scores)
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
sglang/srt/models/llama_reward.py
CHANGED
@@ -21,7 +21,6 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -33,7 +32,6 @@ class LlamaForSequenceClassification(nn.Module):
    ) -> None:
        super().__init__()
        self.config = config
-        self.torchao_config = None
        self.quant_config = quant_config
        self.num_labels = config.num_labels
        self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llava.py
CHANGED
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
        else:
            image_aspect_ratio = "anyres"
        offset_list = []
+        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)
+            image_inputs.image_pad_len.append(new_image_feature_len)
 
        image_inputs.image_offsets = offset_list
        return input_ids
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
        image_inputs = forward_batch.image_inputs
 
        if forward_batch.forward_mode.is_extend():
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
            # Got List[List[str]] extend it to List[str]
            # The length of the List should be equal to batch size
            modalities_list = []
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
                if im and im.modalities is not None:
                    modalities_list.extend(im.modalities)
                if im and im.image_offsets:
-                    max_image_offset.append(
+                    max_image_offset.append(
+                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
+                    )
                else:
                    max_image_offset.append(-1)
 
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)
 
@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):
 
            # Fill in the placeholder for the image
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
            pt = 0
            for i in range(bs):
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
                    continue
 
                start_idx = extend_start_loc_cpu[i]
+                seq_len = extend_seq_lens[i]
                prefix_len = prefix_lens_cpu[i]
 
                # Multiple images
-                for
-
+                for image_idx, image_offset in enumerate(
+                    image_inputs[i].image_offsets
+                ):
+                    if (
+                        image_offset + image_inputs[i].image_pad_len[image_idx]
+                        <= prefix_len
+                    ):
                        continue
+                    if image_offset >= prefix_len + seq_len:
+                        break
 
-                    tmp_image_feature = image_features[pt][
+                    tmp_image_feature = image_features[pt][image_idx]
                    pad_len = tmp_image_feature.shape[0]
 
-
-
+                    input_offset = image_offset - prefix_len
+                    left_idx = start_idx + input_offset
+                    right_idx = left_idx + pad_len
+                    assert right_idx > start_idx
+                    if input_offset < 0:
+                        left_idx = start_idx
+                        tmp_image_feature = tmp_image_feature[-input_offset:]
+                    if right_idx > start_idx + seq_len:
+                        tmp_image_feature = tmp_image_feature[
+                            : start_idx + seq_len - right_idx
+                        ]
+                        right_idx = start_idx + seq_len
                    try:
                        input_embeds[left_idx:right_idx] = tmp_image_feature
                    except RuntimeError as e:
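The llava.py change above records per-image padded lengths (image_pad_len) and then slices vision features that only partially overlap the current extend chunk: a placeholder that starts inside the already-cached prefix contributes only the tail of its feature rows, and one that runs past the chunk end is truncated. A toy re-derivation of that index arithmetic (hypothetical numbers, not the sglang helper itself):

# Toy sketch of the slicing logic added above; fill_range is a made-up helper for illustration.
def fill_range(image_offset, pad_len, start_idx, seq_len, prefix_len):
    input_offset = image_offset - prefix_len
    left_idx = start_idx + input_offset
    right_idx = left_idx + pad_len
    skip_front = 0
    if input_offset < 0:                 # placeholder began inside the cached prefix
        left_idx = start_idx
        skip_front = -input_offset       # feature rows already consumed by the prefix
    keep = pad_len - skip_front
    if right_idx > start_idx + seq_len:  # placeholder spills into a later chunk
        keep -= right_idx - (start_idx + seq_len)
        right_idx = start_idx + seq_len
    return left_idx, right_idx, skip_front, keep

# Placeholder starts at position 6 with 12 feature rows; 10 tokens are already cached,
# and the current chunk holds 8 tokens starting at index 0.
print(fill_range(image_offset=6, pad_len=12, start_idx=0, seq_len=8, prefix_len=10))
# (0, 8, 4, 8): positions 0:8 receive feature rows 4:12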
sglang/srt/models/mixtral.py
CHANGED
@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,7 +38,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -65,6 +68,7 @@ class MixtralMoE(nn.Module):
        prefix: str = "",
    ):
        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size
 
        # Gate always runs at half / full precision for now.
@@ -76,14 +80,13 @@ class MixtralMoE(nn.Module):
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
-            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
@@ -97,6 +100,8 @@ class MixtralMoE(nn.Module):
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(orig_shape)
 
 
@@ -295,7 +300,6 @@ class MixtralForCausalLM(nn.Module):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)
@@ -322,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
 
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
@@ -387,7 +392,5 @@ class MixtralForCausalLM(nn.Module):
                )
                weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
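The mixtral.py change above selects the experts backend at construction time from a server-args flag and moves the tensor-parallel reduction out of the MoE layer: reduce_results=True is dropped and an explicit tensor_model_parallel_all_reduce is applied after self.experts(...) whenever tp_size > 1, so the same reduction path works for either backend. A sketch of the selection pattern with stand-in classes (the real EPMoE and FusedMoE live in sglang.srt.layers and share a constructor signature and make_expert_params_mapping(); the dictionary key matches the diff, while the exact CLI spelling of the flag is not shown here):

# Stand-in classes for illustration only.
class FusedMoE:   # tensor-parallel experts
    pass

class EPMoE:      # expert-parallel experts
    pass

global_server_args_dict = {"enable_ep_moe": False}  # assumed default when expert parallelism is not enabled

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
print(MoEImpl.__name__)  # FusedMoE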
sglang/srt/models/phi3_small.py
CHANGED
@@ -17,13 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ class Phi3SmallForCausalLM(nn.Module):
            quant_config=quant_config,
            prefix="model",
        )
-        self.torchao_config = global_server_args_dict["torchao_config"]
        self.vocab_size = config.vocab_size
        self.mup_width_multiplier = config.mup_width_multiplier
        self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ class Phi3SmallForCausalLM(nn.Module):
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
sglang/srt/models/qwen2.py
CHANGED
@@ -267,6 +267,26 @@ class Qwen2Model(nn.Module):
 
 
 class Qwen2ForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: Qwen2Config,
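Both llama.py and qwen2.py gain the same class-level BitsAndBytes attributes. A small illustration of what such a stacked-params mapping is for: a per-projection checkpoint tensor name is redirected to its fused parameter name plus a shard index. The mapping literal is copied from the diff; map_checkpoint_name is a hypothetical helper, not sglang's actual loader.

bitsandbytes_stacked_params_mapping = {
    # shard_name, weight_name, index
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def map_checkpoint_name(name: str):
    # Return the fused parameter name and the shard slot this tensor should load into.
    for shard_name, (fused_name, shard_index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in name:
            return name.replace(shard_name, fused_name), shard_index
    return name, None

print(map_checkpoint_name("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)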
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -40,12 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
@@ -352,7 +350,6 @@ class Qwen2MoeForCausalLM(nn.Module):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = Qwen2MoeModel(config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@ class Qwen2MoeForCausalLM(nn.Module):
                )
                weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
|
@@ -58,12 +58,10 @@ from sglang.srt.layers.layernorm import RMSNorm
|
|
58
58
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
59
59
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
60
60
|
from sglang.srt.layers.radix_attention import RadixAttention
|
61
|
-
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
62
61
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
63
62
|
ParallelLMHead,
|
64
63
|
VocabParallelEmbedding,
|
65
64
|
)
|
66
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
67
65
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
68
66
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
69
67
|
|
@@ -392,7 +390,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
392
390
|
super().__init__()
|
393
391
|
self.config = config
|
394
392
|
self.quant_config = quant_config
|
395
|
-
self.torchao_config = global_server_args_dict["torchao_config"]
|
396
393
|
self.supports_torch_tp = True
|
397
394
|
self.model = LlamaModel(config, quant_config=quant_config)
|
398
395
|
if self.config.tie_word_embeddings:
|
@@ -503,8 +500,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
503
500
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
504
501
|
weight_loader(param, loaded_weight)
|
505
502
|
|
506
|
-
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
507
|
-
|
508
503
|
|
509
504
|
class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
|
510
505
|
pass
|