sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -33,7 +33,6 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union
 
-import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -72,14 +71,14 @@ class ForwardMode(IntEnum):
     DUMMY_FIRST = auto()
 
     def is_prefill(self):
-        return self
+        return self.is_extend()
 
     def is_extend(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or self ==
+            or self == ForwardMode.TARGET_VERIFY
         )
 
     def is_decode(self):
@@ -97,6 +96,13 @@ class ForwardMode(IntEnum):
     def is_draft_extend(self):
         return self == ForwardMode.DRAFT_EXTEND
 
+    def is_extend_or_draft_extend_or_mixed(self):
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == ForwardMode.MIXED
+        )
+
     def is_cuda_graph(self):
         return (
             self == ForwardMode.DECODE
@@ -104,9 +110,6 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
-    def is_extend_or_draft_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
-
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
@@ -178,6 +181,28 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether the kv cache needs to be attended in current pass
+    attn_attend_prefix_cache: Optional[bool] = None
+    # Number of prefix cache chunks
+    num_prefix_chunks: Optional[int] = None
+    # Index of current chunk, used by attention backend
+    prefix_chunk_idx: Optional[int] = None
+    # Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
+    prefix_chunk_len: Optional[int] = None
+    # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+    prefix_chunk_starts: Optional[torch.Tensor] = None
+    # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+    prefix_chunk_seq_lens: Optional[torch.Tensor] = None
+    # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
+    prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+    # Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
+    prefix_chunk_max_seq_lens: Optional[List[int]] = None
+    # Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
+    prefix_chunk_num_tokens: Optional[List[int]] = None
+    # KV Indices for each chunk
+    prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
 
@@ -399,13 +424,13 @@ class ForwardBatch:
                 )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i,
+            for i, mm_input in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if
+                if mm_input is None:
                     # text only
                     mrope_positions = [
                         [
@@ -416,23 +441,58 @@ class ForwardBatch:
                         ]
                     ] * 3
                 else:
+                    image_grid_thws_list = [
+                        item.image_grid_thws
+                        for item in mm_input.mm_items
+                        if item.image_grid_thws is not None
+                    ]
+                    image_grid_thw = (
+                        None
+                        if len(image_grid_thws_list) == 0
+                        else torch.cat(image_grid_thws_list, dim=0)
+                    )
+
+                    video_grid_thws_list = [
+                        item.video_grid_thws
+                        for item in mm_input.mm_items
+                        if item.video_grid_thws is not None
+                    ]
+                    video_grid_thw = (
+                        None
+                        if len(video_grid_thws_list) == 0
+                        else torch.cat(video_grid_thws_list, dim=0)
+                    )
+
+                    second_per_grid_ts_list = [
+                        item.second_per_grid_ts
+                        for item in mm_input.mm_items
+                        if item.second_per_grid_ts is not None
+                    ]
+                    second_per_grid_ts = (
+                        None
+                        if len(second_per_grid_ts_list) == 0
+                        else torch.cat(second_per_grid_ts_list, dim=0)
+                    )
+
                     # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=
-                            video_grid_thw=
-                            image_token_id=
-                            video_token_id=
+                            ].tolist(),
+                            image_grid_thw=image_grid_thw,
+                            video_grid_thw=video_grid_thw,
+                            image_token_id=hf_config.image_token_id,
+                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
                             vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
                             seq_len=len(self.input_ids),
-                            second_per_grid_ts=
-                            tokens_per_second=
+                            second_per_grid_ts=second_per_grid_ts,
+                            tokens_per_second=getattr(
+                                hf_config.vision_config, "tokens_per_second", None
+                            ),
                         )
                     )
                     batch.multimodal_inputs[i].mrope_position_delta = (
@@ -446,6 +506,128 @@ class ForwardBatch:
         )
         self.mrope_positions = self.mrope_positions.to(torch.int64)
 
+    def get_max_chunk_capacity(self):
+        # Maximum number of tokens in each chunk
+        # TODO: Should be changed to a better value, maybe passed through server args
+        return 128 * 1024
+
+    def set_prefix_chunk_idx(self, idx: int):
+        self.prefix_chunk_idx = idx
+
+    def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
+        self.attn_attend_prefix_cache = attn_attend_prefix_cache
+
+    def prepare_chunked_kv_indices(self, device: torch.device):
+        self.prefix_chunk_kv_indices = []
+        for idx in range(self.num_prefix_chunks):
+            chunk_starts = self.prefix_chunk_starts[idx]
+            chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
+            chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
+            num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
+
+            chunk_kv_indices = torch.empty(
+                num_chunk_tokens, dtype=torch.int32, device=device
+            )
+
+            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                chunk_starts,
+                chunk_seq_lens,
+                chunk_cu_seq_lens,
+                chunk_kv_indices,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+            self.prefix_chunk_kv_indices.append(chunk_kv_indices)
+
+    # Here we suppose the length of each chunk is equal
+    # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
+    # num_prefix_chunks = cdiv(1024, 256) = 4
+    # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
+    # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
+    # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
+    # TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
+    def get_prefix_chunk_seq_lens(
+        self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
+    ):
+        device = prefix_lens.device
+        prefix_chunk_starts = (
+            torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
+            .unsqueeze(1)
+            .expand(-1, self.batch_size)
+            * prefix_chunk_len
+        )
+        prefix_chunk_ends = torch.min(
+            prefix_lens.unsqueeze(0),
+            prefix_chunk_starts + prefix_chunk_len,
+        ).to(torch.int32)
+
+        prefix_chunk_seq_lens = (
+            (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
+        )
+
+        return prefix_chunk_starts, prefix_chunk_seq_lens
+
+    # Called before each attention module if using chunked kv cache for prefill
+    # Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+    def prepare_chunked_prefix_cache_info(self, device: torch.device):
+
+        from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+        assert isinstance(
+            self.token_to_kv_pool, MLATokenToKVPool
+        ), "Currently chunked prefix cache can only be used by Deepseek models"
+
+        if self.prefix_chunk_len is not None:
+            # Chunked kv cache info already prepared by prior modules
+            return
+
+        self.prefix_chunk_idx = -1
+
+        # chunk_capacity is the maximum number of tokens in each chunk
+        chunk_capacity = self.get_max_chunk_capacity()
+        self.prefix_chunk_len = chunk_capacity // self.batch_size
+
+        self.num_prefix_chunks = (
+            max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
+        ) // self.prefix_chunk_len
+
+        # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
+        prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
+            self.get_prefix_chunk_seq_lens(
+                self.extend_prefix_lens,
+                self.num_prefix_chunks,
+                self.prefix_chunk_len,
+            )
+        )
+        _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
+            torch.tensor(self.extend_prefix_lens_cpu),
+            self.num_prefix_chunks,
+            self.prefix_chunk_len,
+        )
+        self.prefix_chunk_starts = prefix_chunk_starts_cuda
+        self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
+
+        # Metadata for attention backend
+        self.prefix_chunk_cu_seq_lens = torch.zeros(
+            self.num_prefix_chunks,
+            self.batch_size + 1,
+            device=device,
+            dtype=torch.int32,
+        )
+        self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
+            dim=1
+        ).to(torch.int32)
+        self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
+            dim=1
+        ).values.tolist()
+
+        self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
+        assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
+
+        # Precompute the kv indices for each chunk
+        self.prepare_chunked_kv_indices(device)
+
 
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
@@ -523,3 +705,40 @@ def compute_position_torch(
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
+@triton.jit
+def create_chunked_prefix_cache_kv_indices(
+    req_to_token_ptr,  # (max_batch, max_context_len,)
+    req_pool_indices_ptr,  # (batch_size,)
+    chunk_start_idx_ptr,  # (batch_size,)
+    chunk_seq_lens_ptr,  # (batch_size,)
+    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
+    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
+
+    # get the token positions of current chunk
+    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
+    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
+
+    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < chunk_seq_len
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + chunk_start_pos
+            + offset,
+            mask=mask,
+        )
+        tl.store(
+            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
+        )
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -75,8 +75,11 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -164,6 +167,7 @@ class ModelRunner:
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -236,11 +240,23 @@ class ModelRunner:
         elif server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
-
-
-
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(server_args)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
             else:
-
+                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
+                    server_args
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
@@ -258,6 +274,16 @@ class ModelRunner:
         else:
             raise ValueError(f"MLA optimization not supported on CPU.")
 
+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -276,7 +302,6 @@ class ModelRunner:
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
             logger.info(
                 "Automatically turn off --chunked-prefill-size for multimodal model."
             )
@@ -294,6 +319,16 @@ class ModelRunner:
         if server_args.enable_deepep_moe:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
+        if not self.use_mla_backend:
+            logger.info("Disable chunked prefix cache for non-MLA backend.")
+            server_args.disable_chunked_prefix_cache = True
+        elif self.page_size > 1:
+            logger.info("Disable chunked prefix cache when page size > 1.")
+            server_args.disable_chunked_prefix_cache = True
+
+        if not server_args.disable_chunked_prefix_cache:
+            logger.info("Chunked prefix cache is turned on.")
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -885,9 +920,6 @@ class ModelRunner:
                 "FlashAttention v3 Backend requires SM>=90. "
                 "Please use `--attention-backend flashinfer`."
             )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )
@@ -924,6 +956,12 @@ class ModelRunner:
             return
 
         if self.server_args.disable_cuda_graph:
+            logger.warning(
+                "\n\nCUDA Graph is DISABLED.\n"
+                "This will cause significant performance degradation.\n"
+                "CUDA Graph should almost never be disabled in most usage scenarios.\n"
+                "If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.\n"
+            )
             return
 
         tic = time.time()
@@ -1060,7 +1098,8 @@ class ModelRunner:
         rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
-
+        is_mrope_enabled = "mrope_section" in rope_scaling
+        return is_mrope_enabled
 
     def save_remote_model(self, url: str):
         from sglang.srt.model_loader.loader import RemoteModelLoader
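
The model_runner.py hunks above change how the attention backend is chosen when `--attention-backend` is not set. The sketch below restates that decision tree with plain booleans standing in for the `is_hopper_with_cuda_12_3`, `is_no_spec_infer_or_topk_one`, `is_fa3_default_architecture`, and `is_flashinfer_available` helpers imported in the diff; it is an illustration, not sglang code.

```python
def pick_default_attention_backend(
    use_mla_backend: bool,
    hopper_with_cuda_12_3: bool,
    no_spec_infer_or_topk_one: bool,
    fa3_default_architecture: bool,
    flashinfer_available: bool,
) -> str:
    """Mirror of the new default-backend selection when --attention-backend is unset.

    The boolean parameters stand in for the is_* helper calls used in the diff.
    """
    if not use_mla_backend:
        # Non-MLA models: prefer FA3 on Hopper with CUDA 12.3 for the supported
        # architectures, otherwise flashinfer when available, else triton.
        if (
            hopper_with_cuda_12_3
            and no_spec_infer_or_topk_one
            and fa3_default_architecture
        ):
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models: prefer FA3 on Hopper with CUDA 12.3, otherwise triton.
    if hopper_with_cuda_12_3 and no_spec_infer_or_topk_one:
        return "fa3"
    return "triton"
```

A later hunk additionally falls back from `fa3` to `triton` when `--kv-cache-dtype fp8_e5m2` is requested, since the warning states that FlashAttention3 only supports fp8_e4m3 when FP8 KV cache is used.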
sglang/srt/model_loader/loader.py
CHANGED
@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)
 
 
 def _get_quantization_config(
-    model_config: ModelConfig,
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
-        quant_config = get_quant_config(
+        quant_config = get_quant_config(
+            model_config, load_config, packed_modules_mapping
+        )
         major, minor = get_device_capability()
 
         if major is not None and minor is not None:
@@ -142,7 +146,10 @@ def _initialize_model(
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
-
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
     return model_class(
         config=model_config.hf_config,
         quant_config=quant_config,
@@ -1064,19 +1071,37 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
-
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
                 index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
+                if (
+                    model_type in ["qwen2_vl", "qwen2_5_vl"]
+                    and "visual" in quant_param_name
+                ):
+                    break
                 if shard_name in quant_param_name:
                     shard_index = index
                     quant_param_name = quant_param_name.replace(shard_name, weight_name)
                     break
 
+            if (
+                model_type in ["qwen2_vl", "qwen2_5_vl"]
+                and "visual" in quant_param_name
+            ):
+                quant_param_name = quant_param_name.replace(
+                    r"attn.qkv.", r"attn.qkv_proj."
+                )
+
             if quant_param_name not in param_dict:
                 raise ValueError(
                     f"Parameter {quant_param_name} not found in the model."
@@ -1104,6 +1129,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
 
         offsets = np.concatenate(([0], np.cumsum(num_elements)))
+        # Make torch infer_schema happy(Compatible with vLLM)
+        offsets = torch.tensor(offsets).cpu()
         set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
         if load_8bit:
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(
 
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
-    model_config: ModelConfig,
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
 
@@ -147,6 +149,7 @@ def get_quant_config(
     # compressed-tensors uses a compressions_config
     hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
     if hf_quant_config is not None:
+        hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
     if model_config.quantization == "bitsandbytes":
@@ -457,7 +460,6 @@ def pt_weights_iterator(
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
-        torch.cuda.empty_cache()
 
 
 def get_gguf_extra_tensor_names(
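
Taken together, the loader.py and weight_utils.py hunks thread each model class's `packed_modules_mapping` into quantization-config construction: `_initialize_model` reads it off the model class, passes it through `_get_quantization_config` and `get_quant_config`, and the latter injects it into the compressed-tensors config before `quant_cls.from_config(...)`. The snippet below is only an illustrative sketch of that final step; the helper name is hypothetical.

```python
from typing import Any, Dict, List


def quant_config_with_packed_modules(
    model_class: Any, hf_quant_config: Dict[str, Any]
) -> Dict[str, Any]:
    """Attach the model's packed-modules mapping to the HF compression config
    before it is handed to the quantization class (illustrative sketch only)."""
    packed_modules_mapping: Dict[str, List[str]] = getattr(
        model_class, "packed_modules_mapping", {}
    )
    config = dict(hf_quant_config)  # copy; the real code updates the dict in place
    config["packed_modules_mapping"] = packed_modules_mapping
    return config
```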
sglang/srt/models/baichuan.py
CHANGED
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
         else:
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
 
sglang/srt/models/chatglm.py
CHANGED
sglang/srt/models/commandr.py
CHANGED
sglang/srt/models/dbrx.py
CHANGED