sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
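A recurring theme in the per-model diffs below is that SiluAndMul and RMSNorm are no longer imported from vllm.model_executor.layers but from the new sglang.srt.layers.activation and sglang.srt.layers.layernorm modules listed above. A minimal sketch of the relocated imports, assuming the sglang copies keep the vLLM-style interfaces and a GPU build where their fused kernels are available:

import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm

act = SiluAndMul()              # splits the last dim in half and returns silu(a) * b
norm = RMSNorm(4096, eps=1e-6)  # RMS normalization over the hidden dimension

# Stand-in for a gate/up projection output of width 2 * hidden_size.
x = torch.randn(2, 8192, device="cuda", dtype=torch.float16)
h = norm(act(x))                # -> shape (2, 4096)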
sglang/srt/model_executor/model_runner.py
CHANGED
(Old-side lines that the diff viewer truncated or collapsed are shown below as "…".)

@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -52,7 +53,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
-    …
+    is_generation_model,
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -130,10 +132,12 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.…
+        self.init_flashinfer()
 
-        …
-        …
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()
 
     def load_model(self):
         logger.info(
@@ -155,7 +159,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
-        if …
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
@@ -165,15 +169,6 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
-        if (
-            self.server_args.efficient_weight_load
-            and "llama" in self.server_args.model_path.lower()
-            and self.server_args.quantization == "fp8"
-        ):
-            from sglang.srt.model_loader.model_loader import get_model
-        else:
-            from vllm.model_executor.model_loader import get_model
-
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -184,6 +179,15 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
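The two attributes added above are how the runner now discovers model capabilities: sliding-window support is probed with hasattr rather than a per-model flag, and is_generation is derived from the architecture list. A self-contained sketch of the probe, with hypothetical model classes standing in for the real ones:

class FullAttentionModel:
    pass

class WindowedModel:
    # mirrors the Gemma2ForCausalLM.get_window_size() added later in this diff
    def get_window_size(self):
        return 4095

def probe_window_size(model):
    return model.get_window_size() if hasattr(model, "get_window_size") else None

assert probe_window_size(FullAttentionModel()) is None
assert probe_window_size(WindowedModel()) == 4095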
@@ -287,8 +291,11 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def …
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
@@ -302,20 +309,47 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        self.…
-        … (13 more old lines collapsed by the viewer)
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
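When a sliding window is present, the runner now keeps two-element lists of paged prefill and decode wrappers sharing one workspace buffer, instead of a single wrapper of each kind; presumably one entry serves the windowed layers and the other the full-attention layers, with the attention code indexing into the list per layer. A hypothetical selection helper to illustrate the shape of the data (the index convention here is an assumption, not taken from the diff):

def pick_decode_wrapper(runner, layer_uses_window: bool):
    # No sliding window: a single BatchDecodeWithPagedKVCacheWrapper instance.
    if runner.sliding_window_size is None:
        return runner.flashinfer_decode_wrapper
    # Sliding window: a two-element list; which slot maps to windowed vs. full
    # attention is assumed here purely for illustration.
    return runner.flashinfer_decode_wrapper[0 if layer_uses_window else 1]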
@@ -350,33 +384,22 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.…
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
-            …
-            …
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+            batch,
+            ForwardMode.DECODE,
         )
+
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.…
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
+            batch,
             forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
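The pattern in both methods above is the same: the long per-field argument list is replaced by InputMetadata.from_schedule_batch(self, batch, forward_mode), which reads the fields off the batch itself. A generic sketch of that refactor with hypothetical classes, outside of sglang:

from dataclasses import dataclass

@dataclass
class Batch:
    seq_lens: list
    prefix_lens: list
    return_logprob: bool

@dataclass
class Metadata:
    seq_lens: list
    prefix_lens: list
    return_logprob: bool

    @classmethod
    def from_batch(cls, batch: "Batch") -> "Metadata":
        # One constructor owns the field mapping instead of every call site.
        return cls(batch.seq_lens, batch.prefix_lens, batch.return_logprob)

meta = Metadata.from_batch(Batch(seq_lens=[3, 5], prefix_lens=[0, 2], return_logprob=False))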
@@ -384,24 +407,18 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.…
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
+            batch,
             forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            …
-            …
-            …
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
@@ -429,8 +446,10 @@ def import_model_classes():
             entry, list
         ):  # To support multiple model classes in one module
             for tmp in entry:
+                assert tmp.__name__ not in model_arch_name_to_cls
                 model_arch_name_to_cls[tmp.__name__] = tmp
         else:
+            assert entry.__name__ not in model_arch_name_to_cls
             model_arch_name_to_cls[entry.__name__] = entry
 
     # compat: some models such as chatglm has incorrect class set in config.json
@@ -440,6 +459,7 @@ def import_model_classes():
         ):
             for remap in module.EntryClassRemapping:
                 if isinstance(remap, tuple) and len(remap) == 2:
+                    assert remap[0] not in model_arch_name_to_cls
                     model_arch_name_to_cls[remap[0]] = remap[1]
 
     return model_arch_name_to_cls
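The three assert statements added to import_model_classes() turn a silent overwrite of an architecture name into a hard failure. A stripped-down sketch of the same guard:

model_arch_name_to_cls = {}

def register(entry):
    # Registering two classes under the same architecture name now fails loudly.
    assert entry.__name__ not in model_arch_name_to_cls, entry.__name__
    model_arch_name_to_cls[entry.__name__] = entry

class LlamaForCausalLM:
    pass

register(LlamaForCausalLM)
# register(LlamaForCausalLM)  # would raise AssertionError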
sglang/srt/models/chatglm.py
CHANGED

@@ -24,8 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -43,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/commandr.py
CHANGED

@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -62,6 +61,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/deepseek.py
CHANGED

@@ -27,9 +27,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -44,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/deepseek_v2.py
CHANGED

@@ -26,9 +26,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -43,6 +41,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_input[..., : self.kv_lora_rank]
         torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
 
-        …
-        …
-        v_input = …
-        …
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
         k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q_input[..., self.kv_lora_rank :] = q_pe
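A shape sketch of the rewritten MLA key/value construction above, with small illustrative sizes (kv_lora_rank and the rope dimension below are hypothetical; only the slicing layout is taken from the diff):

import torch

kv_lora_rank, rope_dim, num_tokens = 512, 64, 4

# Stand-in for the kv_a_proj_with_mqa output: compressed KV plus rotary part.
latent_cache = torch.randn(num_tokens, kv_lora_rank + rope_dim)

v_input = latent_cache[..., :kv_lora_rank]   # low-rank content part, then layernorm
k_input = latent_cache.unsqueeze(1)          # (num_tokens, 1, kv_lora_rank + rope_dim)
k_pe = k_input[..., kv_lora_rank:]           # rotary positional part of the key

assert v_input.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, 1, rope_dim)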
sglang/srt/models/gemma.py
CHANGED

@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,6 +34,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/gemma2.py
CHANGED

@@ -38,13 +38,18 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_window_size(config):
+    return config.sliding_window - 1
+
+
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
 
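A worked check of the off-by-one handled by get_window_size(): HF's sliding_window counts the window inclusive of the current token, while the comment above says SGLang assumes an exclusive window, so the value passed on is sliding_window - 1 (the size below is illustrative):

def get_window_size(config):
    return config.sliding_window - 1

class FakeGemma2Config:
    sliding_window = 4096  # HF-style, inclusive of the current token

w = get_window_size(FakeGemma2Config())
assert w == 4095
# Under the exclusive convention a token attends to its w predecessors plus
# itself, i.e. 4096 positions in total, matching the HF window.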
@@ -201,17 +206,14 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
 
-        …
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
-        del use_sliding_window  # Unused.
+        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            sliding_window_size=get_window_size(config) if use_sliding_window else None,
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -404,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
+    def get_window_size(self):
+        return get_window_size(self.config)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)