sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -81,9 +81,22 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))

     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
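The three new methods give the grammar object a uniform mask interface: allocate one boolean mask for the whole batch, fill one row per request (True marks a disallowed token), then apply it to the logits in place. A minimal driver sketch follows, assuming a list of per-request OutlinesGrammar objects and a [batch_size, vocab_size] logits tensor; the loop is illustrative and is not the sglang scheduler code.

```python
import torch

def apply_grammar_masks(grammars, logits: torch.Tensor) -> None:
    # Hypothetical helper: `grammars` holds one OutlinesGrammar per request.
    batch_size, vocab_size = logits.shape
    # One shared boolean mask for the whole batch; True means "token is disallowed".
    vocab_mask = grammars[0].allocate_vocab_mask(vocab_size, batch_size, logits.device)
    for i, grammar in enumerate(grammars):
        grammar.fill_vocab_mask(vocab_mask, i)
    # Static method: masked positions are set to -inf in place.
    grammars[0].apply_vocab_mask(logits, vocab_mask)
```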
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -21,7 +21,12 @@ from typing import List, Tuple
 import torch

 try:
-    from xgrammar import
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )

     import_error = None
 except ImportError as e:
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])

-    def
-
-
-
-
-
-
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)

@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             self.grammar_cache = None
             return

-
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size

     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.
-                    key_string
-                )
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -141,7 +149,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)

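Read together, these hunks switch the backend to xgrammar's cached compiler plus matcher-owned token bitmasks. Below is a rough end-to-end sketch assembled only from the calls visible in this diff; it is not sglang code, the `constrained_step` helper and its arguments are made up for illustration, `max_rollback_tokens=4` is an arbitrary stand-in for MAX_ROLLBACK_TOKENS, and exact xgrammar signatures may differ between releases.

```python
import torch
from xgrammar import CachedGrammarCompiler, GrammarMatcher, TokenizerInfo

def constrained_step(tokenizer, json_schema: str, logits: torch.Tensor) -> int:
    """One illustrative decoding step; `logits` is assumed to have shape [1, vocab_size]."""
    vocab_size = logits.shape[-1]
    tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
    compiler = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
    ctx = compiler.compile_json_schema_grammar(schema=json_schema)
    matcher = GrammarMatcher(ctx, max_rollback_tokens=4, vocab_size=vocab_size)

    vocab_mask = matcher.allocate_token_bitmask(vocab_size, 1)  # batch of one
    matcher.fill_next_token_bitmask(vocab_mask, 0)
    GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)  # mask logits in place
    next_token = int(torch.argmax(logits, dim=-1).item())
    assert matcher.accept_token(next_token)
    return next_token
```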
sglang/srt/layers/activation.py
CHANGED
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out


+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """

 from enum import Enum, auto
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List

 import torch
 import triton
@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens

             # Some heuristics to check whether to use ragged forward
-            use_ragged = False
             if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                 use_ragged = True
-
-
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            else:
+                use_ragged = False
+                extend_no_prefix = False

             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 prefix_lens,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
@@ -314,7 +316,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

         self.attn_backend = attn_backend
@@ -335,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
             self.update = self.update_single_wrapper

     def update(
-        self,
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -345,8 +351,8 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers
-        encoder_lens
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -363,8 +369,8 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers
-        encoder_lens
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -394,11 +400,11 @@ class FlashInferIndicesUpdaterDecode:

     def update_cross_attention(
         self,
-        req_pool_indices,
-        seq_lens,
-        seq_lens_sum,
-        decode_wrappers
-        encoder_lens
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -425,11 +431,11 @@ class FlashInferIndicesUpdaterDecode:
     def call_begin_forward(
         self,
         wrapper,
-        req_pool_indices,
-        paged_kernel_lens,
-        paged_kernel_lens_sum,
-        kv_indptr,
-        kv_start_idx,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        kv_indptr: torch.Tensor,
+        kv_start_idx: torch.Tensor,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -445,7 +451,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.
+            self.req_to_token.shape[1],
         )

         wrapper.end_forward()
@@ -474,7 +480,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

         self.attn_backend = attn_backend
@@ -496,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
             assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

-    def update(
+    def update(
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
+    ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
-        self,
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
+            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
             paged_kernel_lens = seq_lens
+            paged_kernel_lens_sum = seq_lens_sum

         self.call_begin_forward(
             self.wrapper_ragged,
             self.wrappers_paged[0],
             req_pool_indices,
             paged_kernel_lens,
+            paged_kernel_lens_sum,
             seq_lens,
             prefix_lens,
             None,
@@ -522,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
         )

     def update_sliding_window(
-        self,
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -531,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill:
                     seq_lens,
                     torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
                 )
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
             else:
                 # full attention
                 paged_kernel_lens = seq_lens
+                paged_kernel_lens_sum = seq_lens_sum
+
             kv_start_idx = seq_lens - paged_kernel_lens

             self.call_begin_forward(
@@ -541,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -550,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill:
             )

     def update_cross_attention(
-        self,
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # normal attention
                 paged_kernel_lens = seq_lens
                 kv_start_idx = encoder_lens
+                paged_kernel_lens_sum = seq_lens_sum
             else:
                 # cross attention
                 paged_kernel_lens = encoder_lens
                 kv_start_idx = torch.zeros_like(encoder_lens)
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()

             self.call_begin_forward(
                 self.wrapper_ragged,
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -579,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         wrapper_ragged,
         wrapper_paged,
-        req_pool_indices,
-        paged_kernel_lens,
-
-
-
-
-
-
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        seq_lens: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        kv_start_idx: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        qo_indptr: torch.Tensor,
+        use_ragged: bool,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,
@@ -599,7 +643,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.
+            self.req_to_token.shape[1],
         )

         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +682,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -652,15 +697,15 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for
-
-
-    tl.
-
-
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
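The rewritten kernel gathers each request's KV slot indices from req_to_token into a flat buffer, walking each row in BLOCK_SIZE chunks. As a sanity reference, a plain-PyTorch version of the same gather is sketched below; it is an illustration of the computation, not code shipped in the package.

```python
import torch

def kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                         kv_indptr, kv_start_idx=None):
    """CPU reference for what create_flashinfer_kv_indices_triton produces (illustrative)."""
    total = int(paged_kernel_lens.sum().item())
    kv_indices = torch.empty(total, dtype=torch.int32)
    for pid in range(len(req_pool_indices)):
        row = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[pid])
        out = int(kv_indptr[pid])
        # Copy this request's slot indices into the flat output buffer.
        kv_indices[out : out + length] = req_to_token[row, start : start + length]
    return kv_indices
```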
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -28,9 +27,13 @@ class TritonAttnBackend(AttentionBackend):

         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-
-
-
+
+        if model_runner.server_args.enable_dp_attention:
+            self.num_head = model_runner.model_config.num_attention_heads
+        else:
+            self.num_head = (
+                model_runner.model_config.num_attention_heads // model_runner.tp_size
+            )

         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
@@ -50,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

-            total_num_tokens =
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
@@ -61,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-
-            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

sglang/srt/layers/custom_op_util.py
ADDED
@@ -0,0 +1,26 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from vllm.model_executor.custom_op import CustomOp
+
+
+def register_custom_op(op_name):
+    def decorator(cls):
+        if hasattr(CustomOp, "register"):
+            return CustomOp.register(op_name)(cls)
+        else:
+            return cls
+
+    return decorator
sglang/srt/layers/fused_moe/fused_moe.py
CHANGED
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

-    K = B.shape[2] -
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2] -
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] -
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] -
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
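The padding bookkeeping only matters on the FP8 path: expert weights can be padded along the K dimension by `padding_size`, so the effective hidden size is the stored size minus that pad, while the non-FP8 path forces the pad back to zero. A toy check of that arithmetic, using made-up shapes (only `padding_size` mirrors the module-level constant referenced above):

```python
# Illustrative shapes only.
padding_size = 128
use_fp8 = True

padded_size = padding_size if use_fp8 else 0            # same rule as in the diff
w1_shape = (8, 4096, 2048 + padded_size)                 # [experts, N, K(+pad)]
effective_k = w1_shape[2] - padded_size
assert effective_k == 2048                                # hidden_states.shape[1] must equal this
```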
sglang/srt/layers/layernorm.py
CHANGED
@@ -33,9 +33,12 @@ if is_flashinfer_available():

 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
+
 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
         return x, residual


+@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
sglang/srt/layers/logits_processor.py
CHANGED
@@ -62,21 +62,21 @@ class LogitsMetadata:

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
+        extend_logprob_pruned_lens_cpu = None
+
         if forward_batch.return_logprob:
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+            if forward_batch.forward_mode.is_extend():
+                extend_logprob_pruned_lens_cpu = [
+                    extend_len - start_len
+                    for extend_len, start_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.extend_logprob_start_lens_cpu,
+                    )
+                ]
         else:
             return_top_logprob = False

-        if forward_batch.forward_mode.is_extend():
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
-        else:
-            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-import os
 from typing import Union

 import torch
@@ -8,7 +7,7 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -19,17 +18,13 @@ if is_flashinfer_available():
     )


-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
 logger = logging.getLogger(__name__)


 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin =
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]

     def forward(
         self,
@@ -46,7 +41,8 @@ class Sampler(nn.Module):
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
-
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling