sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
@@ -33,7 +33,6 @@ from dataclasses import dataclass
  from enum import IntEnum, auto
  from typing import TYPE_CHECKING, List, Optional, Union

- import numpy as np
  import torch
  import triton
  import triton.language as tl
@@ -72,14 +71,14 @@ class ForwardMode(IntEnum):
  DUMMY_FIRST = auto()

  def is_prefill(self):
- return self == ForwardMode.PREFILL
+ return self.is_extend()

  def is_extend(self):
  return (
  self == ForwardMode.EXTEND
  or self == ForwardMode.MIXED
  or self == ForwardMode.DRAFT_EXTEND
- or self == self.TARGET_VERIFY
+ or self == ForwardMode.TARGET_VERIFY
  )

  def is_decode(self):
@@ -97,6 +96,13 @@ class ForwardMode(IntEnum):
  def is_draft_extend(self):
  return self == ForwardMode.DRAFT_EXTEND

+ def is_extend_or_draft_extend_or_mixed(self):
+ return (
+ self == ForwardMode.EXTEND
+ or self == ForwardMode.DRAFT_EXTEND
+ or self == ForwardMode.MIXED
+ )
+
  def is_cuda_graph(self):
  return (
  self == ForwardMode.DECODE
@@ -104,9 +110,6 @@ class ForwardMode(IntEnum):
  or self == ForwardMode.IDLE
  )

- def is_extend_or_draft_extend(self):
- return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
-
  def is_dummy_first(self):
  return self == ForwardMode.DUMMY_FIRST

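Taken together, these ForwardMode hunks make is_prefill() an alias for is_extend() (so MIXED, DRAFT_EXTEND, and TARGET_VERIFY batches now count as prefill), spell the TARGET_VERIFY comparison with the enum class explicitly, and replace is_extend_or_draft_extend() with a variant that also covers MIXED while still excluding TARGET_VERIFY. A condensed, self-contained sketch of the new behavior (the member subset and ordering here are illustrative, not the full class):

from enum import IntEnum, auto

class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    TARGET_VERIFY = auto()
    IDLE = auto()

    def is_extend(self):
        return self in (
            ForwardMode.EXTEND,
            ForwardMode.MIXED,
            ForwardMode.DRAFT_EXTEND,
            ForwardMode.TARGET_VERIFY,
        )

    def is_prefill(self):
        # Now delegates to is_extend() instead of comparing against a single PREFILL member.
        return self.is_extend()

    def is_extend_or_draft_extend_or_mixed(self):
        return self in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND, ForwardMode.MIXED)

# TARGET_VERIFY is the case the new helper excludes while is_extend() includes it.
assert ForwardMode.TARGET_VERIFY.is_prefill()
assert not ForwardMode.TARGET_VERIFY.is_extend_or_draft_extend_or_mixed()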
@@ -178,6 +181,28 @@ class ForwardBatch:
  extend_logprob_start_lens_cpu: Optional[List[int]] = None
  extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

+ # For MLA chunked prefix cache used in chunked prefill
+ # Tell attention backend whether the kv cache needs to be attended in current pass
+ attn_attend_prefix_cache: Optional[bool] = None
+ # Number of prefix cache chunks
+ num_prefix_chunks: Optional[int] = None
+ # Index of current chunk, used by attention backend
+ prefix_chunk_idx: Optional[int] = None
+ # Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
+ prefix_chunk_len: Optional[int] = None
+ # Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+ prefix_chunk_starts: Optional[torch.Tensor] = None
+ # Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
+ prefix_chunk_seq_lens: Optional[torch.Tensor] = None
+ # Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
+ prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+ # Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
+ prefix_chunk_max_seq_lens: Optional[List[int]] = None
+ # Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
+ prefix_chunk_num_tokens: Optional[List[int]] = None
+ # KV Indices for each chunk
+ prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+
  # For multimodal
  mm_inputs: Optional[List[MultimodalInputs]] = None

@@ -399,13 +424,13 @@ class ForwardBatch:
  )
  elif self.forward_mode.is_extend():
  extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
- for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
+ for i, mm_input in enumerate(batch.multimodal_inputs):
  extend_start_loc, extend_seq_len, extend_prefix_len = (
  extend_start_loc_cpu[i],
  batch.extend_seq_lens[i],
  batch.extend_prefix_lens[i],
  )
- if multimodal_inputs is None:
+ if mm_input is None:
  # text only
  mrope_positions = [
  [
@@ -416,23 +441,58 @@ class ForwardBatch:
  ]
  ] * 3
  else:
+ image_grid_thws_list = [
+ item.image_grid_thws
+ for item in mm_input.mm_items
+ if item.image_grid_thws is not None
+ ]
+ image_grid_thw = (
+ None
+ if len(image_grid_thws_list) == 0
+ else torch.cat(image_grid_thws_list, dim=0)
+ )
+
+ video_grid_thws_list = [
+ item.video_grid_thws
+ for item in mm_input.mm_items
+ if item.video_grid_thws is not None
+ ]
+ video_grid_thw = (
+ None
+ if len(video_grid_thws_list) == 0
+ else torch.cat(video_grid_thws_list, dim=0)
+ )
+
+ second_per_grid_ts_list = [
+ item.second_per_grid_ts
+ for item in mm_input.mm_items
+ if item.second_per_grid_ts is not None
+ ]
+ second_per_grid_ts = (
+ None
+ if len(second_per_grid_ts_list) == 0
+ else torch.cat(second_per_grid_ts_list, dim=0)
+ )
+
  # TODO: current qwen2-vl do not support radix cache since mrope position calculation
  mrope_positions, mrope_position_delta = (
  MRotaryEmbedding.get_input_positions(
  input_tokens=self.input_ids[
  extend_start_loc : extend_start_loc + extend_seq_len
- ],
- image_grid_thw=multimodal_inputs.image_grid_thws,
- video_grid_thw=multimodal_inputs.video_grid_thws,
- image_token_id=multimodal_inputs.im_token_id,
- video_token_id=multimodal_inputs.video_token_id,
+ ].tolist(),
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ image_token_id=hf_config.image_token_id,
+ video_token_id=hf_config.video_token_id,
  vision_start_token_id=hf_config.vision_start_token_id,
  vision_end_token_id=hf_config.vision_end_token_id,
  spatial_merge_size=hf_config.vision_config.spatial_merge_size,
  context_len=0,
  seq_len=len(self.input_ids),
- second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
- tokens_per_second=hf_config.vision_config.tokens_per_second,
+ second_per_grid_ts=second_per_grid_ts,
+ tokens_per_second=getattr(
+ hf_config.vision_config, "tokens_per_second", None
+ ),
  )
  )
  batch.multimodal_inputs[i].mrope_position_delta = (
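The rewritten branch above now gathers grid metadata across all mm_items of a request before calling MRotaryEmbedding.get_input_positions. In isolation, the gather-or-None pattern it uses looks like the following sketch (attribute names follow the hunk; the sample tensors are made up for illustration):

from types import SimpleNamespace
import torch

# Two items carry image grids, one does not; the field stays None when nothing is found.
mm_items = [
    SimpleNamespace(image_grid_thws=torch.tensor([[1, 4, 4]])),
    SimpleNamespace(image_grid_thws=None),
    SimpleNamespace(image_grid_thws=torch.tensor([[1, 8, 8]])),
]

image_grid_thws_list = [
    item.image_grid_thws for item in mm_items if item.image_grid_thws is not None
]
image_grid_thw = (
    None if len(image_grid_thws_list) == 0 else torch.cat(image_grid_thws_list, dim=0)
)
print(image_grid_thw.shape)  # torch.Size([2, 3])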
@@ -446,6 +506,128 @@ class ForwardBatch:
  )
  self.mrope_positions = self.mrope_positions.to(torch.int64)

+ def get_max_chunk_capacity(self):
+ # Maximum number of tokens in each chunk
+ # TODO: Should be changed to a better value, maybe passed through server args
+ return 128 * 1024
+
+ def set_prefix_chunk_idx(self, idx: int):
+ self.prefix_chunk_idx = idx
+
+ def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
+ self.attn_attend_prefix_cache = attn_attend_prefix_cache
+
+ def prepare_chunked_kv_indices(self, device: torch.device):
+ self.prefix_chunk_kv_indices = []
+ for idx in range(self.num_prefix_chunks):
+ chunk_starts = self.prefix_chunk_starts[idx]
+ chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
+ chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
+ num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
+
+ chunk_kv_indices = torch.empty(
+ num_chunk_tokens, dtype=torch.int32, device=device
+ )
+
+ create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+ self.req_to_token_pool.req_to_token,
+ self.req_pool_indices,
+ chunk_starts,
+ chunk_seq_lens,
+ chunk_cu_seq_lens,
+ chunk_kv_indices,
+ self.req_to_token_pool.req_to_token.shape[1],
+ )
+ self.prefix_chunk_kv_indices.append(chunk_kv_indices)
+
+ # Here we suppose the length of each chunk is equal
+ # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
+ # num_prefix_chunks = cdiv(1024, 256) = 4
+ # prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
+ # prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
+ # prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
+ # TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
+ def get_prefix_chunk_seq_lens(
+ self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
+ ):
+ device = prefix_lens.device
+ prefix_chunk_starts = (
+ torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
+ .unsqueeze(1)
+ .expand(-1, self.batch_size)
+ * prefix_chunk_len
+ )
+ prefix_chunk_ends = torch.min(
+ prefix_lens.unsqueeze(0),
+ prefix_chunk_starts + prefix_chunk_len,
+ ).to(torch.int32)
+
+ prefix_chunk_seq_lens = (
+ (prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
+ )
+
+ return prefix_chunk_starts, prefix_chunk_seq_lens
+
+ # Called before each attention module if using chunked kv cache for prefill
+ # Some of the codes are adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+ def prepare_chunked_prefix_cache_info(self, device: torch.device):
+
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+ assert isinstance(
+ self.token_to_kv_pool, MLATokenToKVPool
+ ), "Currently chunked prefix cache can only be used by Deepseek models"
+
+ if self.prefix_chunk_len is not None:
+ # Chunked kv cache info already prepared by prior modules
+ return
+
+ self.prefix_chunk_idx = -1
+
+ # chunk_capacity is the maximum number of tokens in each chunk
+ chunk_capacity = self.get_max_chunk_capacity()
+ self.prefix_chunk_len = chunk_capacity // self.batch_size
+
+ self.num_prefix_chunks = (
+ max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
+ ) // self.prefix_chunk_len
+
+ # Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
+ prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
+ self.get_prefix_chunk_seq_lens(
+ self.extend_prefix_lens,
+ self.num_prefix_chunks,
+ self.prefix_chunk_len,
+ )
+ )
+ _, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
+ torch.tensor(self.extend_prefix_lens_cpu),
+ self.num_prefix_chunks,
+ self.prefix_chunk_len,
+ )
+ self.prefix_chunk_starts = prefix_chunk_starts_cuda
+ self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
+
+ # Metadata for attention backend
+ self.prefix_chunk_cu_seq_lens = torch.zeros(
+ self.num_prefix_chunks,
+ self.batch_size + 1,
+ device=device,
+ dtype=torch.int32,
+ )
+ self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
+ dim=1
+ ).to(torch.int32)
+ self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
+ dim=1
+ ).values.tolist()
+
+ self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
+ assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
+
+ # Precompute the kv indices for each chunk
+ self.prepare_chunked_kv_indices(device)
+

  def compute_position_triton(
  extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
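The comment on get_prefix_chunk_seq_lens gives a worked example; the same arithmetic can be checked in isolation with a few lines of PyTorch (batch of 4 prefixes and chunk length 256, matching the numbers in that comment; this is a standalone sketch, not sglang code):

import torch

prefix_lens = torch.tensor([256, 512, 768, 1024], dtype=torch.int32)
prefix_chunk_len = 256
# cdiv(max prefix, chunk len) -> 4 chunks
num_prefix_chunks = (int(prefix_lens.max()) + prefix_chunk_len - 1) // prefix_chunk_len

starts = (
    torch.arange(num_prefix_chunks, dtype=torch.int32)
    .unsqueeze(1)
    .expand(-1, prefix_lens.numel())
    * prefix_chunk_len
)
ends = torch.min(prefix_lens.unsqueeze(0), starts + prefix_chunk_len)
seq_lens = (ends - starts).clamp(min=0)

print(starts.tolist())    # [[0, 0, 0, 0], [256, 256, 256, 256], [512, ...], [768, ...]]
print(seq_lens.tolist())  # [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]

# Cumulative per-chunk lengths feed the attention backend, as in prepare_chunked_prefix_cache_info.
cu_seq_lens = torch.zeros(num_prefix_chunks, prefix_lens.numel() + 1, dtype=torch.int32)
cu_seq_lens[:, 1:] = seq_lens.cumsum(dim=1).to(torch.int32)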
@@ -523,3 +705,40 @@ def compute_position_torch(
  @torch.compile(dynamic=True, backend=get_compiler_backend())
  def clamp_position(seq_lens):
  return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
+ @triton.jit
+ def create_chunked_prefix_cache_kv_indices(
+ req_to_token_ptr, # (max_batch, max_context_len,)
+ req_pool_indices_ptr, # (batch_size,)
+ chunk_start_idx_ptr, # (batch_size,)
+ chunk_seq_lens_ptr, # (batch_size,)
+ chunk_cu_seq_lens_ptr, # (batch_size + 1,)
+ chunk_kv_indices_ptr, # (num_chunk_tokens,)
+ req_to_token_ptr_stride: tl.constexpr,
+ ):
+ BLOCK_SIZE: tl.constexpr = 512
+ pid = tl.program_id(axis=0)
+
+ # find the req pool idx, this is for batch to token
+ req_pool_index = tl.load(req_pool_indices_ptr + pid)
+ chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
+
+ # get the token positions of current chunk
+ chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
+ chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
+
+ num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
+ for i in range(num_loop):
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+ mask = offset < chunk_seq_len
+ data = tl.load(
+ req_to_token_ptr
+ + req_pool_index * req_to_token_ptr_stride
+ + chunk_start_pos
+ + offset,
+ mask=mask,
+ )
+ tl.store(
+ chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
+ )
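For readers unfamiliar with Triton, the kernel above is equivalent to the following per-request gather written in plain PyTorch (a reference sketch under the same tensor shapes, not code that sglang runs):

import torch

def chunked_prefix_cache_kv_indices_ref(
    req_to_token: torch.Tensor,       # (max_batch, max_context_len)
    req_pool_indices: torch.Tensor,   # (batch_size,)
    chunk_starts: torch.Tensor,       # (batch_size,)
    chunk_seq_lens: torch.Tensor,     # (batch_size,)
    chunk_cu_seq_lens: torch.Tensor,  # (batch_size + 1,)
) -> torch.Tensor:
    # For each request, copy the req_to_token slots of the current prefix chunk
    # into a flat chunk_kv_indices buffer at that request's cumulative offset.
    num_chunk_tokens = int(chunk_cu_seq_lens[-1])
    out = torch.empty(num_chunk_tokens, dtype=torch.int32)
    for i in range(req_pool_indices.numel()):
        start = int(chunk_starts[i])
        length = int(chunk_seq_lens[i])
        offset = int(chunk_cu_seq_lens[i])
        row = req_to_token[int(req_pool_indices[i])]
        out[offset : offset + length] = row[start : start + length]
    return out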
sglang/srt/model_executor/model_runner.py
@@ -73,10 +73,14 @@ from sglang.srt.utils import (
  MultiprocessingSerializer,
  enable_show_time_cost,
  get_available_gpu_memory,
+ get_bool_env_var,
  init_custom_process_group,
  is_cuda,
+ is_fa3_default_architecture,
  is_flashinfer_available,
  is_hip,
+ is_hopper_with_cuda_12_3,
+ is_no_spec_infer_or_topk_one,
  monkey_patch_p2p_access_check,
  monkey_patch_vllm_gguf_config,
  set_cpu_offload_max_bytes,
@@ -124,10 +128,7 @@ class ModelRunner:
  self.page_size = server_args.page_size
  self.req_to_token_pool = req_to_token_pool
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
- self.use_mla_backend = (
- self.model_config.attention_arch == AttentionArch.MLA
- and not server_args.disable_mla
- )
+ self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
  self.attention_chunk_size = model_config.attention_chunk_size

  # Model-specific adjustment
@@ -136,18 +137,12 @@ class ModelRunner:
  if server_args.show_time_cost:
  enable_show_time_cost()

- if server_args.disable_outlines_disk_cache:
- from outlines.caching import disable_cache
-
- disable_cache()
-
  # Global vars
  global_server_args_dict.update(
  {
  "attention_backend": server_args.attention_backend,
  "sampling_backend": server_args.sampling_backend,
  "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
- "disable_mla": server_args.disable_mla,
  "torchao_config": server_args.torchao_config,
  "enable_nan_detection": server_args.enable_nan_detection,
  "enable_dp_attention": server_args.enable_dp_attention,
@@ -157,13 +152,13 @@ class ModelRunner:
  "device": server_args.device,
  "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
  "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
- "enable_flashmla": server_args.enable_flashmla,
  "disable_radix_cache": server_args.disable_radix_cache,
  "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+ "moe_dense_tp_size": server_args.moe_dense_tp_size,
  "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
  "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
  "n_share_experts_fusion": server_args.n_share_experts_fusion,
- "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+ "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
  "use_mla_backend": self.use_mla_backend,
  }
  )
@@ -225,29 +220,38 @@ class ModelRunner:
  def model_specific_adjustment(self):
  server_args = self.server_args

- if server_args.enable_flashinfer_mla:
- # TODO: remove this branch after enable_flashinfer_mla is deprecated
- logger.info("MLA optimization is turned on. Use flashinfer backend.")
- server_args.attention_backend = "flashinfer"
- elif server_args.enable_flashmla:
- # TODO: remove this branch after enable_flashmla is deprecated
- logger.info("MLA optimization is turned on. Use flashmla decode.")
- server_args.attention_backend = "flashmla"
- elif server_args.attention_backend is None:
+ if server_args.attention_backend is None:
  # By default, use flashinfer for non-mla attention and triton for mla attention
  if not self.use_mla_backend:
- server_args.attention_backend = (
- "flashinfer" if is_flashinfer_available() else "triton"
- )
+ if (
+ is_hopper_with_cuda_12_3()
+ and is_no_spec_infer_or_topk_one(server_args)
+ and is_fa3_default_architecture(self.model_config.hf_config)
+ ):
+ server_args.attention_backend = "fa3"
+ else:
+ server_args.attention_backend = (
+ "flashinfer" if is_flashinfer_available() else "triton"
+ )
  else:
- server_args.attention_backend = "triton"
+ if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
+ server_args
+ ):
+ server_args.attention_backend = "fa3"
+ else:
+ server_args.attention_backend = "triton"
  logger.info(
  f"Attention backend not set. Use {server_args.attention_backend} backend by default."
  )
  elif self.use_mla_backend:
  # TODO: add MLA optimization on CPU
  if server_args.device != "cpu":
- if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
+ if server_args.attention_backend in [
+ "flashinfer",
+ "fa3",
+ "triton",
+ "flashmla",
+ ]:
  logger.info(
  f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
  )
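Condensed, the new default-backend selection in model_specific_adjustment reads as the decision table below. The predicate names correspond to the helpers imported above; the standalone function takes plain booleans and is only a sketch of the logic, not sglang API:

def pick_default_attention_backend(
    use_mla_backend: bool,
    hopper_with_cuda_12_3: bool,
    no_spec_infer_or_topk_one: bool,
    fa3_default_architecture: bool,
    flashinfer_available: bool,
) -> str:
    # Applies only when --attention-backend was not set explicitly.
    if not use_mla_backend:
        if hopper_with_cuda_12_3 and no_spec_infer_or_topk_one and fa3_default_architecture:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models: fa3 on Hopper with CUDA 12.3+ and a compatible speculative config, else triton.
    if hopper_with_cuda_12_3 and no_spec_infer_or_topk_one:
        return "fa3"
    return "triton"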
@@ -258,6 +262,16 @@ class ModelRunner:
  else:
  raise ValueError(f"MLA optimization not supported on CPU.")

+ if (
+ server_args.attention_backend == "fa3"
+ and server_args.kv_cache_dtype == "fp8_e5m2"
+ ):
+ logger.warning(
+ "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+ "Setting attention backend to triton."
+ )
+ server_args.attention_backend = "triton"
+
  if server_args.enable_double_sparsity:
  logger.info(
  "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -276,7 +290,6 @@ class ModelRunner:
  f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
  f"because this is a multimodal model."
  )
-
  logger.info(
  "Automatically turn off --chunked-prefill-size for multimodal model."
  )
@@ -294,6 +307,15 @@ class ModelRunner:
  if server_args.enable_deepep_moe:
  logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

+ if not self.use_mla_backend:
+ server_args.disable_chunked_prefix_cache = True
+ elif self.page_size > 1:
+ logger.info("Disable chunked prefix cache when page size > 1.")
+ server_args.disable_chunked_prefix_cache = True
+
+ if not server_args.disable_chunked_prefix_cache:
+ logger.info("Chunked prefix cache is turned on.")
+
  def init_torch_distributed(self):
  logger.info("Init torch distributed begin.")

@@ -352,10 +374,16 @@ class ModelRunner:
  local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
  if self.tp_size > 1:
  if min_per_gpu_memory < local_gpu_memory * 0.9:
- raise ValueError(
- "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
- f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
- )
+ if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
+ logger.warning(
+ "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+ f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+ )
+ else:
+ raise ValueError(
+ "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+ f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+ )

  logger.info(
  f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
@@ -885,9 +913,6 @@ class ModelRunner:
  "FlashAttention v3 Backend requires SM>=90. "
  "Please use `--attention-backend flashinfer`."
  )
- logger.warning(
- "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
- )
  from sglang.srt.layers.attention.flashattention_backend import (
  FlashAttentionBackend,
  )
@@ -924,6 +949,12 @@ class ModelRunner:
  return

  if self.server_args.disable_cuda_graph:
+ logger.warning(
+ "\n\nCUDA Graph is DISABLED.\n"
+ "This will cause significant performance degradation.\n"
+ "CUDA Graph should almost never be disabled in most usage scenarios.\n"
+ "If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.\n"
+ )
  return

  tic = time.time()
@@ -1060,7 +1091,8 @@ class ModelRunner:
  rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
  if rope_scaling is None:
  return False
- return rope_scaling.get("type", None) == "mrope"
+ is_mrope_enabled = "mrope_section" in rope_scaling
+ return is_mrope_enabled

  def save_remote_model(self, url: str):
  from sglang.srt.model_loader.loader import RemoteModelLoader
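The mrope check above now keys on the presence of "mrope_section" rather than on rope_scaling["type"] == "mrope". A minimal illustration of the updated predicate (the sample dicts are hypothetical shapes of a rope_scaling config, not copied from any particular model):

def model_is_mrope(rope_scaling) -> bool:
    # Mirrors the updated check: the "mrope_section" key is what matters,
    # regardless of how the "type" field is spelled.
    if rope_scaling is None:
        return False
    return "mrope_section" in rope_scaling

assert model_is_mrope({"type": "default", "mrope_section": [16, 24, 24]})
assert not model_is_mrope({"type": "linear", "factor": 2.0})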
sglang/srt/model_loader/loader.py
@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)


  def _get_quantization_config(
- model_config: ModelConfig, load_config: LoadConfig
+ model_config: ModelConfig,
+ load_config: LoadConfig,
+ packed_modules_mapping: Dict[str, List[str]],
  ) -> Optional[QuantizationConfig]:
  """Get the quantization config."""
  if model_config.quantization is not None:
- quant_config = get_quant_config(model_config, load_config)
+ quant_config = get_quant_config(
+ model_config, load_config, packed_modules_mapping
+ )
  major, minor = get_device_capability()

  if major is not None and minor is not None:
@@ -142,7 +146,10 @@ def _initialize_model(
  ) -> nn.Module:
  """Initialize a model with the given configurations."""
  model_class, _ = get_model_architecture(model_config)
- quant_config = _get_quantization_config(model_config, load_config)
+ packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+ quant_config = _get_quantization_config(
+ model_config, load_config, packed_modules_mapping
+ )
  return model_class(
  config=model_config.hf_config,
  quant_config=quant_config,
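The new packed_modules_mapping argument is read off the model class here and threaded through _get_quantization_config into get_quant_config (see the weight_utils.py hunks further down, where it is attached to the HF compression config). Its shape is a plain dict from fused module names to their constituent projections; the entries below are the conventional fused-projection example and are illustrative only, since each model class defines its own mapping:

# Illustrative shape of a model class's packed_modules_mapping attribute.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}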
@@ -1064,19 +1071,37 @@ class BitsAndBytesModelLoader(BaseModelLoader):

  param_dict = dict(model.named_parameters())
  stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+ model_type = model_config.hf_config.model_type
  for quant_param_name in quant_state_dict:
  non_stacked_param_name = quant_param_name
-
+ if model_type == "mllama" and "vision_model" in quant_param_name:
+ # adapt to VisionAttention
+ quant_param_name = quant_param_name.replace(
+ "self_attn.o_proj", "self_attn.proj"
+ )
  shard_index = 0
  for shard_name, (
  weight_name,
  index,
  ) in model.bitsandbytes_stacked_params_mapping.items():
+ if (
+ model_type in ["qwen2_vl", "qwen2_5_vl"]
+ and "visual" in quant_param_name
+ ):
+ break
  if shard_name in quant_param_name:
  shard_index = index
  quant_param_name = quant_param_name.replace(shard_name, weight_name)
  break

+ if (
+ model_type in ["qwen2_vl", "qwen2_5_vl"]
+ and "visual" in quant_param_name
+ ):
+ quant_param_name = quant_param_name.replace(
+ r"attn.qkv.", r"attn.qkv_proj."
+ )
+
  if quant_param_name not in param_dict:
  raise ValueError(
  f"Parameter {quant_param_name} not found in the model."
@@ -1104,6 +1129,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
  num_elements[seq] = math.prod(quant_state.shape) // pack_ratio

  offsets = np.concatenate(([0], np.cumsum(num_elements)))
+ # Make torch infer_schema happy(Compatible with vLLM)
+ offsets = torch.tensor(offsets).cpu()
  set_weight_attrs(param, {"bnb_shard_offsets": offsets})

  if load_8bit:
sglang/srt/model_loader/weight_utils.py
@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(

  # TODO(woosuk): Move this to other place.
  def get_quant_config(
- model_config: ModelConfig, load_config: LoadConfig
+ model_config: ModelConfig,
+ load_config: LoadConfig,
+ packed_modules_mapping: Dict[str, List[str]],
  ) -> QuantizationConfig:
  quant_cls = get_quantization_config(model_config.quantization)

@@ -147,6 +149,7 @@
  # compressed-tensors uses a compressions_config
  hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
  if hf_quant_config is not None:
+ hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
  return quant_cls.from_config(hf_quant_config)
  # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
  if model_config.quantization == "bitsandbytes":
@@ -457,7 +460,6 @@ def pt_weights_iterator(
  state = torch.load(bin_file, map_location="cpu", weights_only=True)
  yield from state.items()
  del state
- torch.cuda.empty_cache()


  def get_gguf_extra_tensor_names(
sglang/srt/models/baichuan.py
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
  scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
  else:
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )