sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/outlines_backend.py
@@ -81,9 +81,22 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
 
     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
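Both grammar backends now expose the same three-step mask interface: allocate one mask for the whole batch, fill one row per constrained request, and apply the mask to the logits in place. A minimal caller-side sketch, assuming a hypothetical per-request grammars list and a (batch_size, vocab_size) logits tensor; the real call sites live in sglang's sampling path:

import torch

def apply_grammar_masks(grammars, logits: torch.Tensor) -> None:
    # grammars: list of grammar objects (or None) aligned with the batch rows.
    batch_size, vocab_size = logits.shape
    active = [g for g in grammars if g is not None]
    if not active:
        return
    vocab_mask = active[0].allocate_vocab_mask(vocab_size, batch_size, logits.device)
    for i, grammar in enumerate(grammars):
        if grammar is not None:
            grammar.fill_vocab_mask(vocab_mask, i)  # mark row i's disallowed tokens
    # apply_vocab_mask is a @staticmethod; masked positions become -inf.
    active[0].apply_vocab_mask(logits, vocab_mask)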
sglang/srt/constrained/xgrammar_backend.py
@@ -21,7 +21,12 @@ from typing import List, Tuple
 import torch
 
 try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
 
     import_error = None
 except ImportError as e:
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             self.grammar_cache = None
             return
 
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
            try:
-                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
-                    key_string
-                )
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
            except RuntimeError as e:
                logging.warning(
                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -141,7 +149,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
sglang/srt/layers/activation.py
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
sglang/srt/layers/attention/flashinfer_backend.py
@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """
 
 from enum import Enum, auto
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
 
             # Some heuristics to check whether to use ragged forward
-            use_ragged = False
             if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                 use_ragged = True
-
-            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            else:
+                use_ragged = False
+                extend_no_prefix = False
 
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 prefix_lens,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
@@ -314,7 +316,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -335,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
            self.update = self.update_single_wrapper
 
     def update(
-        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -345,8 +351,8 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers=None,
-        encoder_lens=None,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -363,8 +369,8 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers=None,
-        encoder_lens=None,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -394,11 +400,11 @@
 
     def update_cross_attention(
         self,
-        req_pool_indices,
-        seq_lens,
-        seq_lens_sum,
-        decode_wrappers=None,
-        encoder_lens=None,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -425,11 +431,11 @@
     def call_begin_forward(
         self,
         wrapper,
-        req_pool_indices,
-        paged_kernel_lens,
-        paged_kernel_lens_sum,
-        kv_indptr,
-        kv_start_idx,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        kv_indptr: torch.Tensor,
+        kv_start_idx: torch.Tensor,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -445,7 +451,7 @@
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )
 
         wrapper.end_forward()
@@ -474,7 +480,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -496,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper
 
-    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+    def update(
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
+    ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
 
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
+            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
+            paged_kernel_lens_sum = seq_lens_sum
 
        self.call_begin_forward(
            self.wrapper_ragged,
            self.wrappers_paged[0],
            req_pool_indices,
            paged_kernel_lens,
+            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            None,
@@ -522,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -531,9 +559,12 @@
                     seq_lens,
                     torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
                 )
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
             else:
                 # full attention
                 paged_kernel_lens = seq_lens
+                paged_kernel_lens_sum = seq_lens_sum
+
             kv_start_idx = seq_lens - paged_kernel_lens
 
             self.call_begin_forward(
@@ -541,6 +572,7 @@
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -550,23 +582,32 @@
             )
 
     def update_cross_attention(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # normal attention
                paged_kernel_lens = seq_lens
                kv_start_idx = encoder_lens
+                paged_kernel_lens_sum = seq_lens_sum
            else:
                # cross attention
                paged_kernel_lens = encoder_lens
                kv_start_idx = torch.zeros_like(encoder_lens)
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
 
            self.call_begin_forward(
                self.wrapper_ragged,
                self.wrappers_paged[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
+                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx,
@@ -579,19 +620,22 @@
         self,
         wrapper_ragged,
         wrapper_paged,
-        req_pool_indices,
-        paged_kernel_lens,
-        seq_lens,
-        prefix_lens,
-        kv_start_idx,
-        kv_indptr,
-        qo_indptr,
-        use_ragged,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        seq_lens: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        kv_start_idx: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        qo_indptr: torch.Tensor,
+        use_ragged: bool,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,
@@ -599,7 +643,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )
 
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +682,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-    max_context_len: tl.constexpr,
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -652,15 +697,15 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
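The kernel now indexes req_to_token with an explicit row stride (req_to_token_ptr_stride) instead of pre-offsetting the pointers. As a reading aid, a pure-PyTorch sketch of the gather each kernel program performs, one loop iteration per request instead of one GPU program (illustrative only, not part of the package):

import torch

def create_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                                kv_indptr, kv_start_idx, kv_indices):
    # For request i, copy its token slots [start, start + length) from row
    # req_pool_indices[i] of req_to_token into the flat kv_indices buffer,
    # starting at output offset kv_indptr[i].
    for i in range(len(req_pool_indices)):
        req = int(req_pool_indices[i].item())
        start = int(kv_start_idx[i].item()) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i].item())
        out = int(kv_indptr[i].item())
        kv_indices[out : out + length] = req_to_token[req, start : start + length]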
sglang/srt/layers/attention/triton_backend.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -28,9 +27,13 @@ class TritonAttnBackend(AttentionBackend):
 
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+
+        if model_runner.server_args.enable_dp_attention:
+            self.num_head = model_runner.model_config.num_attention_heads
+        else:
+            self.num_head = (
+                model_runner.model_config.num_attention_heads // model_runner.tp_size
+            )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
@@ -50,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
@@ -61,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = forward_batch.extend_prefix_lens
-            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
 
sglang/srt/layers/custom_op_util.py (new file)
@@ -0,0 +1,26 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from vllm.model_executor.custom_op import CustomOp
+
+
+def register_custom_op(op_name):
+    def decorator(cls):
+        if hasattr(CustomOp, "register"):
+            return CustomOp.register(op_name)(cls)
+        else:
+            return cls
+
+    return decorator
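The decorator registers an op name with vLLM's CustomOp registry when CustomOp.register exists (newer vLLM) and simply returns the class unchanged on older vLLM builds. An illustrative use with a hypothetical op, mirroring the decorators added to activation.py and layernorm.py in this release:

import torch
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op


@register_custom_op("sglang_example_identity")  # hypothetical op name
class ExampleIdentity(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Trivial op body; only the registration pattern is the point here.
        return x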
sglang/srt/layers/fused_moe/fused_moe.py
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
@@ -262,7 +265,7 @@
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
-    K = B.shape[2] - padding_size
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
@@ -279,7 +282,7 @@
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2] - padding_size,
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -480,8 +483,12 @@
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
sglang/srt/layers/layernorm.py
@@ -33,9 +33,12 @@ if is_flashinfer_available():
 
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
+
 logger = logging.getLogger(__name__)
 
 
+@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
+@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
sglang/srt/layers/logits_processor.py
@@ -62,21 +62,21 @@ class LogitsMetadata:
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
+        extend_logprob_pruned_lens_cpu = None
+
         if forward_batch.return_logprob:
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+            if forward_batch.forward_mode.is_extend():
+                extend_logprob_pruned_lens_cpu = [
+                    extend_len - start_len
+                    for extend_len, start_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.extend_logprob_start_lens_cpu,
+                    )
+                ]
         else:
             return_top_logprob = False
 
-        if forward_batch.forward_mode.is_extend():
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
-        else:
-            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
sglang/srt/layers/sampler.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from typing import Union
 
 import torch
@@ -8,7 +7,7 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -19,17 +18,13 @@ if is_flashinfer_available():
     )
 
 
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
 logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
 
     def forward(
         self,
@@ -46,7 +41,8 @@ class Sampler(nn.Module):
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
-            exit(1) if crash_on_warning else None
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
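The deleted inline CI check is replaced by a shared crash_on_warnings() helper imported from sglang.srt.utils. A minimal sketch of its presumed behavior, inferred from the removed lines (the actual helper in utils.py may differ):

import os


def crash_on_warnings() -> bool:
    # Presumed equivalent of the removed module-level check: only crash on
    # warnings when running inside SGLang's CI (SGLANG_IS_IN_CI=true).
    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"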
sglang/srt/layers/torchao_utils.py
@@ -62,6 +62,8 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
                granularity=GRANULARITY_MAP[granularity]
            ),
        )
+    else:
+        raise ValueError(f"Unexpected config: {torchao_config}")
 
    return dummy_linear.weight
 