sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -33,7 +33,6 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union
 
-import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -72,14 +71,14 @@ class ForwardMode(IntEnum):
     DUMMY_FIRST = auto()
 
     def is_prefill(self):
-        return self == ForwardMode.PREFILL
+        return self.is_extend()
 
     def is_extend(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or self == self.TARGET_VERIFY
+            or self == ForwardMode.TARGET_VERIFY
         )
 
     def is_decode(self):
@@ -97,6 +96,13 @@ class ForwardMode(IntEnum):
     def is_draft_extend(self):
         return self == ForwardMode.DRAFT_EXTEND
 
+    def is_extend_or_draft_extend_or_mixed(self):
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == ForwardMode.MIXED
+        )
+
     def is_cuda_graph(self):
         return (
             self == ForwardMode.DECODE
@@ -104,9 +110,6 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
-    def is_extend_or_draft_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
-
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
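The predicate changes above make is_prefill an alias for is_extend, so MIXED, DRAFT_EXTEND, and TARGET_VERIFY batches now also count as prefill, and the TARGET_VERIFY comparison is spelled with the class name rather than self. A minimal standalone sketch of the resulting classification (the enum here is reduced to a few members for illustration; the real enum lives in sglang/srt/model_executor/forward_batch_info.py):

    from enum import IntEnum, auto

    class ForwardMode(IntEnum):
        # Reduced stand-in for illustration only.
        EXTEND = auto()
        DECODE = auto()
        MIXED = auto()
        DRAFT_EXTEND = auto()
        TARGET_VERIFY = auto()

        def is_extend(self):
            return self in (
                ForwardMode.EXTEND,
                ForwardMode.MIXED,
                ForwardMode.DRAFT_EXTEND,
                ForwardMode.TARGET_VERIFY,
            )

        def is_prefill(self):
            # As of 0.4.5.post1, prefill is an alias for extend.
            return self.is_extend()

    for mode in ForwardMode:
        print(mode.name, mode.is_prefill())
    # EXTEND, MIXED, DRAFT_EXTEND, TARGET_VERIFY -> True; DECODE -> False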
 
@@ -178,6 +181,28 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether the kv cache needs to be attended in current pass
+    attn_attend_prefix_cache: Optional[bool] = None
+    # Number of prefix cache chunks
+    num_prefix_chunks: Optional[int] = None
+    # Index of current chunk, used by attention backend
+    prefix_chunk_idx: Optional[int] = None
+    # Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
+    prefix_chunk_len: Optional[int] = None
+    # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+    prefix_chunk_starts: Optional[torch.Tensor] = None
+    # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+    prefix_chunk_seq_lens: Optional[torch.Tensor] = None
+    # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
+    prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+    # Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
+    prefix_chunk_max_seq_lens: Optional[List[int]] = None
+    # Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
+    prefix_chunk_num_tokens: Optional[List[int]] = None
+    # KV Indices for each chunk
+    prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
 
@@ -399,13 +424,13 @@ class ForwardBatch:
                 )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
+            for i, mm_input in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if multimodal_inputs is None:
+                if mm_input is None:
                     # text only
                     mrope_positions = [
                         [
@@ -416,23 +441,58 @@ class ForwardBatch:
                         ]
                     ] * 3
                 else:
+                    image_grid_thws_list = [
+                        item.image_grid_thws
+                        for item in mm_input.mm_items
+                        if item.image_grid_thws is not None
+                    ]
+                    image_grid_thw = (
+                        None
+                        if len(image_grid_thws_list) == 0
+                        else torch.cat(image_grid_thws_list, dim=0)
+                    )
+
+                    video_grid_thws_list = [
+                        item.video_grid_thws
+                        for item in mm_input.mm_items
+                        if item.video_grid_thws is not None
+                    ]
+                    video_grid_thw = (
+                        None
+                        if len(video_grid_thws_list) == 0
+                        else torch.cat(video_grid_thws_list, dim=0)
+                    )
+
+                    second_per_grid_ts_list = [
+                        item.second_per_grid_ts
+                        for item in mm_input.mm_items
+                        if item.second_per_grid_ts is not None
+                    ]
+                    second_per_grid_ts = (
+                        None
+                        if len(second_per_grid_ts_list) == 0
+                        else torch.cat(second_per_grid_ts_list, dim=0)
+                    )
+
                     # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=multimodal_inputs.image_grid_thws,
-                            video_grid_thw=multimodal_inputs.video_grid_thws,
-                            image_token_id=multimodal_inputs.im_token_id,
-                            video_token_id=multimodal_inputs.video_token_id,
+                            ].tolist(),
+                            image_grid_thw=image_grid_thw,
+                            video_grid_thw=video_grid_thw,
+                            image_token_id=hf_config.image_token_id,
+                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
                             vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
                             seq_len=len(self.input_ids),
-                            second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
-                            tokens_per_second=hf_config.vision_config.tokens_per_second,
+                            second_per_grid_ts=second_per_grid_ts,
+                            tokens_per_second=getattr(
+                                hf_config.vision_config, "tokens_per_second", None
+                            ),
                         )
                     )
                     batch.multimodal_inputs[i].mrope_position_delta = (
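The mrope hunk above now aggregates grid metadata across mm_input.mm_items before calling MRotaryEmbedding.get_input_positions, instead of reading a single tensor off multimodal_inputs. A self-contained illustration of that aggregation step (FakeItem is a made-up stand-in, not an sglang class):

    import torch
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeItem:
        # Stand-in for one multimodal item; only the field used here.
        image_grid_thws: Optional[torch.Tensor] = None

    items = [
        FakeItem(image_grid_thws=torch.tensor([[1, 16, 16]])),
        FakeItem(),  # an item without image grids contributes nothing
        FakeItem(image_grid_thws=torch.tensor([[1, 32, 24]])),
    ]

    grids = [it.image_grid_thws for it in items if it.image_grid_thws is not None]
    image_grid_thw = None if len(grids) == 0 else torch.cat(grids, dim=0)
    print(image_grid_thw)  # tensor([[ 1, 16, 16], [ 1, 32, 24]])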
@@ -446,6 +506,128 @@ class ForwardBatch:
             )
         self.mrope_positions = self.mrope_positions.to(torch.int64)
 
+    def get_max_chunk_capacity(self):
+        # Maximum number of tokens in each chunk
+        # TODO: Should be changed to a better value, maybe passed through server args
+        return 128 * 1024
+
+    def set_prefix_chunk_idx(self, idx: int):
+        self.prefix_chunk_idx = idx
+
+    def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
+        self.attn_attend_prefix_cache = attn_attend_prefix_cache
+
+    def prepare_chunked_kv_indices(self, device: torch.device):
+        self.prefix_chunk_kv_indices = []
+        for idx in range(self.num_prefix_chunks):
+            chunk_starts = self.prefix_chunk_starts[idx]
+            chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
+            chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
+            num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
+
+            chunk_kv_indices = torch.empty(
+                num_chunk_tokens, dtype=torch.int32, device=device
+            )
+
+            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                chunk_starts,
+                chunk_seq_lens,
+                chunk_cu_seq_lens,
+                chunk_kv_indices,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+            self.prefix_chunk_kv_indices.append(chunk_kv_indices)
+
+    # Here we suppose the length of each chunk is equal
+    # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
+    # num_prefix_chunks = cdiv(1024, 256) = 4
+    # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
+    # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
+    # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
+    # TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
+    def get_prefix_chunk_seq_lens(
+        self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
+    ):
+        device = prefix_lens.device
+        prefix_chunk_starts = (
+            torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
+            .unsqueeze(1)
+            .expand(-1, self.batch_size)
+            * prefix_chunk_len
+        )
+        prefix_chunk_ends = torch.min(
+            prefix_lens.unsqueeze(0),
+            prefix_chunk_starts + prefix_chunk_len,
+        ).to(torch.int32)
+
+        prefix_chunk_seq_lens = (
+            (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
+        )
+
+        return prefix_chunk_starts, prefix_chunk_seq_lens
+
+    # Called before each attention module if using chunked kv cache for prefill
+    # Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+    def prepare_chunked_prefix_cache_info(self, device: torch.device):
+
+        from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+        assert isinstance(
+            self.token_to_kv_pool, MLATokenToKVPool
+        ), "Currently chunked prefix cache can only be used by Deepseek models"
+
+        if self.prefix_chunk_len is not None:
+            # Chunked kv cache info already prepared by prior modules
+            return
+
+        self.prefix_chunk_idx = -1
+
+        # chunk_capacity is the maximum number of tokens in each chunk
+        chunk_capacity = self.get_max_chunk_capacity()
+        self.prefix_chunk_len = chunk_capacity // self.batch_size
+
+        self.num_prefix_chunks = (
+            max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
+        ) // self.prefix_chunk_len
+
+        # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
+        prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
+            self.get_prefix_chunk_seq_lens(
+                self.extend_prefix_lens,
+                self.num_prefix_chunks,
+                self.prefix_chunk_len,
+            )
+        )
+        _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
+            torch.tensor(self.extend_prefix_lens_cpu),
+            self.num_prefix_chunks,
+            self.prefix_chunk_len,
+        )
+        self.prefix_chunk_starts = prefix_chunk_starts_cuda
+        self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
+
+        # Metadata for attention backend
+        self.prefix_chunk_cu_seq_lens = torch.zeros(
+            self.num_prefix_chunks,
+            self.batch_size + 1,
+            device=device,
+            dtype=torch.int32,
+        )
+        self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
+            dim=1
+        ).to(torch.int32)
+        self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
+            dim=1
+        ).values.tolist()
+
+        self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
+        assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
+
+        # Precompute the kv indices for each chunk
+        self.prepare_chunked_kv_indices(device)
+
 
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
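As a sanity check on the chunk layout described in the comments above get_prefix_chunk_seq_lens, the following standalone PyTorch snippet (independent of sglang; the variable names are local to the example) reproduces the chunk starts and per-chunk lengths for prefix lengths [256, 512, 768, 1024] with prefix_chunk_len = 256:

    import torch

    prefix_lens = torch.tensor([256, 512, 768, 1024], dtype=torch.int32)
    prefix_chunk_len = 256
    batch_size = prefix_lens.numel()
    num_prefix_chunks = (int(prefix_lens.max()) + prefix_chunk_len - 1) // prefix_chunk_len

    # Same math as get_prefix_chunk_seq_lens: every chunk spans prefix_chunk_len
    # positions, clipped against each sequence's actual prefix length.
    chunk_starts = (
        torch.arange(num_prefix_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, batch_size)
        * prefix_chunk_len
    )
    chunk_ends = torch.min(prefix_lens.unsqueeze(0), chunk_starts + prefix_chunk_len)
    chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

    print(chunk_starts.tolist())
    # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
    print(chunk_seq_lens.tolist())
    # [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]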
@@ -523,3 +705,40 @@ def compute_position_torch(
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
+@triton.jit
+def create_chunked_prefix_cache_kv_indices(
+    req_to_token_ptr,  # (max_batch, max_context_len,)
+    req_pool_indices_ptr,  # (batch_size,)
+    chunk_start_idx_ptr,  # (batch_size,)
+    chunk_seq_lens_ptr,  # (batch_size,)
+    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
+    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
+
+    # get the token positions of current chunk
+    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
+    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
+
+    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < chunk_seq_len
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + chunk_start_pos
+            + offset,
+            mask=mask,
+        )
+        tl.store(
+            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
+        )
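For readers who do not follow Triton, the kernel above performs a per-request masked gather from the request-to-token table into a flat index buffer. A plain PyTorch reference of the same indexing (a sketch with hypothetical small inputs, not the code used at runtime):

    import torch

    def chunked_prefix_kv_indices_ref(
        req_to_token: torch.Tensor,      # (max_batch, max_context_len)
        req_pool_indices: torch.Tensor,  # (batch_size,)
        chunk_starts: torch.Tensor,      # (batch_size,)
        chunk_seq_lens: torch.Tensor,    # (batch_size,)
    ) -> torch.Tensor:
        # Concatenate, per request, the KV-cache slot indices of its current
        # prefix chunk; this is what the Triton kernel writes in parallel.
        out = []
        for pool_idx, start, length in zip(req_pool_indices, chunk_starts, chunk_seq_lens):
            out.append(req_to_token[pool_idx, start : start + length])
        return torch.cat(out).to(torch.int32)

    # Tiny example: two requests, one 3-token chunk each.
    req_to_token = torch.arange(20).reshape(2, 10)
    print(
        chunked_prefix_kv_indices_ref(
            req_to_token,
            req_pool_indices=torch.tensor([1, 0]),
            chunk_starts=torch.tensor([0, 2]),
            chunk_seq_lens=torch.tensor([3, 3]),
        )
    )
    # tensor([10, 11, 12,  2,  3,  4], dtype=torch.int32)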
sglang/srt/model_executor/model_runner.py CHANGED
@@ -75,8 +75,11 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -164,6 +167,7 @@ class ModelRunner:
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -236,11 +240,23 @@ class ModelRunner:
         elif server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
-                server_args.attention_backend = (
-                    "flashinfer" if is_flashinfer_available() else "triton"
-                )
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(server_args)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
             else:
-                server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
+                    server_args
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
@@ -258,6 +274,16 @@ class ModelRunner:
             else:
                 raise ValueError(f"MLA optimization not supported on CPU.")
 
+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -276,7 +302,6 @@ class ModelRunner:
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
             logger.info(
                 "Automatically turn off --chunked-prefill-size for multimodal model."
             )
@@ -294,6 +319,16 @@ class ModelRunner:
         if server_args.enable_deepep_moe:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
+        if not self.use_mla_backend:
+            logger.info("Disable chunked prefix cache for non-MLA backend.")
+            server_args.disable_chunked_prefix_cache = True
+        elif self.page_size > 1:
+            logger.info("Disable chunked prefix cache when page size > 1.")
+            server_args.disable_chunked_prefix_cache = True
+
+        if not server_args.disable_chunked_prefix_cache:
+            logger.info("Chunked prefix cache is turned on.")
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -885,9 +920,6 @@ class ModelRunner:
                 "FlashAttention v3 Backend requires SM>=90. "
                 "Please use `--attention-backend flashinfer`."
            )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )
@@ -924,6 +956,12 @@ class ModelRunner:
             return
 
         if self.server_args.disable_cuda_graph:
+            logger.warning(
+                "\n\nCUDA Graph is DISABLED.\n"
+                "This will cause significant performance degradation.\n"
+                "CUDA Graph should almost never be disabled in most usage scenarios.\n"
+                "If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.\n"
+            )
             return
 
         tic = time.time()
@@ -1060,7 +1098,8 @@ class ModelRunner:
         rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
-        return rope_scaling.get("type", None) == "mrope"
+        is_mrope_enabled = "mrope_section" in rope_scaling
+        return is_mrope_enabled
 
     def save_remote_model(self, url: str):
         from sglang.srt.model_loader.loader import RemoteModelLoader
sglang/srt/model_loader/loader.py CHANGED
@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)
 
 
 def _get_quantization_config(
-    model_config: ModelConfig, load_config: LoadConfig
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config, load_config)
+        quant_config = get_quant_config(
+            model_config, load_config, packed_modules_mapping
+        )
         major, minor = get_device_capability()
 
         if major is not None and minor is not None:
@@ -142,7 +146,10 @@ def _initialize_model(
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
-    quant_config = _get_quantization_config(model_config, load_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
     return model_class(
         config=model_config.hf_config,
         quant_config=quant_config,
@@ -1064,19 +1071,37 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
-
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
                 index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
+                if (
+                    model_type in ["qwen2_vl", "qwen2_5_vl"]
+                    and "visual" in quant_param_name
+                ):
+                    break
                 if shard_name in quant_param_name:
                     shard_index = index
                     quant_param_name = quant_param_name.replace(shard_name, weight_name)
                     break
 
+            if (
+                model_type in ["qwen2_vl", "qwen2_5_vl"]
+                and "visual" in quant_param_name
+            ):
+                quant_param_name = quant_param_name.replace(
+                    r"attn.qkv.", r"attn.qkv_proj."
+                )
+
             if quant_param_name not in param_dict:
                 raise ValueError(
                     f"Parameter {quant_param_name} not found in the model."
@@ -1104,6 +1129,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
+            # Make torch infer_schema happy(Compatible with vLLM)
+            offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
         if load_8bit:
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(
 
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
-    model_config: ModelConfig, load_config: LoadConfig
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
 
@@ -147,6 +149,7 @@ def get_quant_config(
     # compressed-tensors uses a compressions_config
     hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
     if hf_quant_config is not None:
+        hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
     if model_config.quantization == "bitsandbytes":
@@ -457,7 +460,6 @@ def pt_weights_iterator(
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
-        torch.cuda.empty_cache()
 
 
 def get_gguf_extra_tensor_names(
sglang/srt/models/baichuan.py CHANGED
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
         else:
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
 
sglang/srt/models/chatglm.py CHANGED
@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/commandr.py CHANGED
@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:
sglang/srt/models/dbrx.py CHANGED
@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/deepseek.py CHANGED
@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )