sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED

```diff
@@ -50,10 +50,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -89,6 +91,7 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -117,15 +120,21 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1
-
+
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
@@ -158,6 +167,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -198,7 +211,7 @@ class ModelRunner:
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # TODO(liangan1): Just use gloo to bypass the initilization fail
             # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
@@ -264,11 +277,35 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
+
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
 
         # Parse other args
         self.sliding_window_size = (
```
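The new `TorchMemorySaverAdapter` (added in `sglang/srt/torch_memory_saver_adapter.py`, whose body is not shown in this section) is created from `--enable-memory-saver` and its `region()` context manager wraps the weight allocation above and the KV-cache pools further down. A minimal sketch of that adapter pattern, assuming an optional memory-saver backend; the `begin_region`/`end_region` hooks and `create_adapter` name here are illustrative, not the actual sglang API:

```python
# Sketch only: mimics TorchMemorySaverAdapter.create(enable=...).region() from the
# diff. The backend hooks are hypothetical placeholders.
from contextlib import contextmanager, nullcontext


class _NoOpAdapter:
    """Fallback used when --enable-memory-saver is not set."""

    def region(self):
        # Allocations proceed normally; nothing is tracked for later release.
        return nullcontext()


class _RealAdapter:
    """Hypothetical wrapper around a memory-saver backend."""

    def __init__(self, backend):
        self._backend = backend

    @contextmanager
    def region(self):
        # Tensors allocated inside this region can be released while the engine
        # is idle and restored on demand by the backend.
        self._backend.begin_region()
        try:
            yield
        finally:
            self._backend.end_region()


def create_adapter(enable: bool, backend=None):
    # A no-op context manager unless the feature is explicitly enabled.
    return _RealAdapter(backend) if enable else _NoOpAdapter()


# Usage mirrors the diff: model loading happens inside the region.
adapter = create_adapter(enable=False)
with adapter.region():
    pass  # load weights here
```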
```diff
@@ -386,7 +423,7 @@ class ModelRunner:
 
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:
@@ -509,6 +546,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
```
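The hunk above adds an `fp8_e4m3` option that maps to `torch.float8_e4m3fn` on CUDA; together with the `load_kv_cache_scales` path shown earlier, each layer can carry a scaling factor for its FP8 cache. A standalone illustration of what such a scale does (arbitrary numbers, not the sglang kernel path; requires a PyTorch build with float8 support):

```python
# float8_e4m3 has a small dynamic range (max finite value 448), so KV entries are
# divided by a per-layer scale before storage and multiplied back on load.
import torch

k = torch.randn(4, 8) * 30.0            # raw key states, may exceed the fp8 range
k_scale = k.abs().max() / 448.0          # per-tensor scale for this illustration

k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)    # quantize into the KV cache
k_restored = k_fp8.to(torch.float32) * k_scale    # dequantize at attention time

print((k - k_restored).abs().max())      # small quantization error
```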
```diff
@@ -556,6 +596,7 @@ class ModelRunner:
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
             use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -568,6 +609,7 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -578,6 +620,7 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -587,6 +630,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
@@ -627,7 +671,6 @@ class ModelRunner:
         )
 
     def init_double_sparsity_channel_config(self, selected_channel):
-
         selected_channel = "." + selected_channel + "_proj"
         self.sorted_channels = []
         # load channel config
@@ -718,7 +761,7 @@ class ModelRunner:
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
-            raise ValueError(f"
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
```
sglang/srt/models/chatglm.py CHANGED

```diff
@@ -23,8 +23,8 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
```
sglang/srt/models/dbrx.py CHANGED

```diff
@@ -25,8 +25,8 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
```
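Both chatglm.py and dbrx.py now import their configuration classes from `sglang.srt.configs` instead of `vllm.transformers_utils.configs`; the file list above shows the vendored `configs/chatglm.py`, `configs/dbrx.py`, and a four-line `configs/__init__.py`. A plausible shape of that `__init__.py`, offered as an assumption (the actual file is not shown in this section):

```python
# Assumed re-export layer based on the new imports used by chatglm.py and dbrx.py.
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig

__all__ = ["ChatGLMConfig", "DbrxConfig"]
```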
sglang/srt/models/grok.py CHANGED

```diff
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")
 
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(config, i, quant_config=quant_config)
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
 
-        # Monkey patch _prepare_weights to load pre-sharded weights
         if (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
         else:
             self.use_presharded_weights = False
 
+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        use_presharded_weights: bool | None = None,
     ):
-        if use_presharded_weights is None:
-            use_presharded_weights = self.use_presharded_weights
         num_experts = self.config.num_local_experts
 
         stacked_params_mapping = [
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
-                    if use_presharded_weights:
-                        extra_kwargs = {
-                            "use_presharded_weights": use_presharded_weights
-                        }
-                    else:
-                        extra_kwargs = {}
-
                     load_weight_wrapper(
                         name,
                         loaded_weight,
                         name,
                         shard_id=shard_id,
                         expert_id=expert_id,
-                        **extra_kwargs,
                     )
                     break
                 else:
```
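grok.py now threads a `use_presharded_weights` flag from `Grok1ForCausalLM` down through `Grok1Model`, the decoder layers, the MoE block, and the merged/row-parallel linear layers, instead of passing it per call into `load_weights`. A hedged sketch of what such a flag typically changes inside a tensor-parallel weight loader (skip the per-rank narrowing when the checkpoint is already sharded); this is illustrative, not the actual sglang/vllm `RowParallelLinear` implementation:

```python
import torch


def load_row_parallel_weight(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    tp_rank: int,
    tp_size: int,
    use_presharded_weights: bool = False,
) -> None:
    if use_presharded_weights:
        # Checkpoint already contains only this rank's shard (e.g. pre-sharded
        # Grok-1 weights), so copy it directly.
        shard = loaded_weight
    else:
        # Full weight on disk: narrow the input dimension to this rank's slice.
        shard_size = loaded_weight.shape[1] // tp_size
        shard = loaded_weight.narrow(1, tp_rank * shard_size, shard_size)
    param.copy_(shard)


# Example: rank 1 of 2 loads its half of an 8x8 matrix.
full = torch.arange(64.0).reshape(8, 8)
dst = torch.empty(8, 4)
load_row_parallel_weight(dst, full, tp_rank=1, tp_size=2)
```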
sglang/srt/models/llama.py CHANGED

```diff
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -100,6 +104,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +137,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +199,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -206,6 +216,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -292,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
 
 class LlamaForCausalLM(nn.Module):
 
@@ -527,9 +562,16 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
 
 
-EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
+class InternLM3ForCausalLM(LlamaForCausalLM):
+    pass
+
+
+EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]
```
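The decoder layer now enables QKV and output-projection bias whenever the config exposes either `attention_bias` (llamafied Qwen2.5 checkpoints) or `bias` (InternLM-style configs), as the hunk above shows. A tiny check that mirrors that detection logic against minimal stand-in configs (real `transformers` config objects behave the same way via `getattr`):

```python
from types import SimpleNamespace


def wants_attention_bias(config) -> bool:
    # Same expression as the attention_bias computation added in the diff.
    return bool(
        getattr(config, "attention_bias", False) or getattr(config, "bias", False)
    )


print(wants_attention_bias(SimpleNamespace()))                     # False: plain Llama
print(wants_attention_bias(SimpleNamespace(attention_bias=True)))  # True: llamafied Qwen2.5
print(wants_attention_bias(SimpleNamespace(bias=True)))            # True: InternLM-style config
```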
sglang/srt/models/qwen2.py CHANGED

```diff
@@ -362,5 +362,16 @@ class Qwen2ForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 EntryClass = Qwen2ForCausalLM
```
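These hooks let a speculative-decoding draft model re-bind its embedding and LM head to the target model's tensors so only one copy lives on the GPU; the EAGLE worker in this release is the intended caller. A hedged sketch of the pattern using tiny stand-in modules instead of real Qwen2ForCausalLM instances:

```python
import torch
from torch import nn


class TinyCausalLM(nn.Module):
    def __init__(self, vocab=16, hidden=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    # Same contract as the methods added to Qwen2ForCausalLM above.
    def get_embed_and_head(self):
        return self.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        self.embed_tokens.weight = embed
        self.lm_head.weight = head


target, draft = TinyCausalLM(), TinyCausalLM()
draft.set_embed_and_head(*target.get_embed_and_head())
assert draft.lm_head.weight is target.lm_head.weight  # weights are now shared
```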
sglang/srt/models/qwen2_eagle.py ADDED

```diff
@@ -0,0 +1,131 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
+
+Qwen2Config = None
+
+
+class Qwen2DecoderLayer(Qwen2DecoderLayer):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class Qwen2Model(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen2DecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [Qwen2ForCausalLMEagle]
```
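The new EAGLE draft model fuses each token embedding with the hidden state carried in `forward_batch.spec_info.hidden_states` through `fc`, a linear layer projecting `2 * hidden_size` back down to `hidden_size` before the decoder layers run. A quick shape check of that fusion step (dimensions are arbitrary):

```python
import torch

hidden_size, seq_len = 8, 5
fc = torch.nn.Linear(hidden_size * 2, hidden_size)

token_embeds = torch.randn(seq_len, hidden_size)   # embed_tokens(input_ids)
target_hidden = torch.randn(seq_len, hidden_size)  # forward_batch.spec_info.hidden_states

fused = fc(torch.cat((token_embeds, target_hidden), dim=-1))
print(fused.shape)  # torch.Size([5, 8])
```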
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py CHANGED

```diff
@@ -3,6 +3,11 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
```
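This penalizer (and `sampling_batch_info.py` below) gains a CUDA fast path via `sgl_kernel.sampling_scaling_penalties`, keeping the pure-PyTorch `torch.where` math as the fallback: positive logits are divided by the penalty and negative logits are multiplied by it, so previously generated tokens always become less likely. A quick numeric check of the fallback math:

```python
import torch

logits = torch.tensor([2.0, -2.0, 0.5])
penalties = torch.tensor([1.2, 1.2, 1.0])  # 1.0 means "token not seen yet"

penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)  # tensor([ 1.6667, -2.4000,  0.5000])
```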
sglang/srt/sampling/sampling_batch_info.py CHANGED

```diff
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
 import sglang.srt.sampling.penaltylib as penaltylib
 
 logger = logging.getLogger(__name__)
@@ -232,6 +238,7 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
@@ -244,11 +251,14 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
```
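Note that `apply_logits_bias` writes the penalized values back with `logits[:] = ...`: slice assignment mutates the caller's tensor in place, whereas rebinding the local name would leave the caller's logits untouched. A small demonstration of that distinction (standalone, not sglang code):

```python
import torch


def scale_in_place(logits: torch.Tensor) -> None:
    logits[:] = logits / 2          # mutates the caller's tensor


def scale_by_rebinding(logits: torch.Tensor) -> None:
    logits = logits / 2             # only rebinds the local name


a = torch.tensor([4.0, 8.0])
b = a.clone()
scale_in_place(a)
scale_by_rebinding(b)
print(a)  # tensor([2., 4.])  -- updated
print(b)  # tensor([4., 8.])  -- unchanged
```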