sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/utils.py
@@ -0,0 +1,223 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+ # Copyright 2023 The vLLM team.
+ # Adapted from
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ import dataclasses
+ import logging
+ import os
+ import pickle
+ import time
+ from collections import deque
+ from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+
+ import torch
+ from torch.distributed import TCPStore
+
+ logger = logging.getLogger(__name__)
+
+
+ def ensure_divisibility(numerator, denominator):
+     """Ensure that numerator is divisible by the denominator."""
+     assert numerator % denominator == 0, "{} is not divisible by {}".format(
+         numerator, denominator
+     )
+
+
+ def divide(numerator, denominator):
+     """Ensure that numerator is divisible by the denominator and return
+     the division value."""
+     ensure_divisibility(numerator, denominator)
+     return numerator // denominator
+
+
+ def split_tensor_along_last_dim(
+     tensor: torch.Tensor,
+     num_partitions: int,
+     contiguous_split_chunks: bool = False,
+ ) -> Sequence[torch.Tensor]:
+     """Split a tensor along its last dimension.
+
+     Arguments:
+         tensor: input tensor.
+         num_partitions: number of partitions to split the tensor
+         contiguous_split_chunks: If True, make each chunk contiguous
+                                  in memory.
+
+     Returns:
+         A list of Tensors
+     """
+     # Get the size and dimension.
+     last_dim = tensor.dim() - 1
+     last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+     # Split.
+     tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+     # NOTE: torch.split does not create contiguous tensors by default.
+     if contiguous_split_chunks:
+         return tuple(chunk.contiguous() for chunk in tensor_list)
+
+     return tensor_list
+
+
+ def get_pp_indices(
+     num_hidden_layers: int, pp_rank: int, pp_size: int
+ ) -> Tuple[int, int]:
+     """Try to evenly distribute layers across partitions.
+     If the number of layers is not divisible by the number of partitions,
+     the last partition will have the remaining layers.
+     """
+     # partition_list_str can be set to None in sglang
+     partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
+     if partition_list_str is not None:
+         try:
+             partitions = [int(layer) for layer in partition_list_str.split(",")]
+         except ValueError as err:
+             raise ValueError(
+                 "Invalid partition string: {}".format(partition_list_str)
+             ) from err
+         if len(partitions) != pp_size:
+             raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+         if sum(partitions) != num_hidden_layers:
+             raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+         start_layer = sum(partitions[:pp_rank])
+         end_layer = start_layer + partitions[pp_rank]
+     else:
+         layers_per_partition = num_hidden_layers // pp_size
+         start_layer = pp_rank * layers_per_partition
+         end_layer = start_layer + layers_per_partition
+
+         if pp_rank == pp_size - 1:
+             end_layer = num_hidden_layers
+
+     return (start_layer, end_layer)
+
+
+ @dataclasses.dataclass
+ class StatelessProcessGroup:
+     """A dataclass to hold a metadata store, and the rank, world_size of the
+     group. Only use it to communicate metadata between processes.
+     For data-plane communication, create NCCL-related objects.
+     """
+
+     rank: int
+     world_size: int
+     store: torch._C._distributed_c10d.Store
+     data_expiration_seconds: int = 3600  # 1 hour
+
+     # dst rank -> counter
+     send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+     # src rank -> counter
+     recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+     broadcast_send_counter: int = 0
+     broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+
+     # A deque to store the data entries, with key and timestamp.
+     entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
+
+     def __post_init__(self):
+         assert self.rank < self.world_size
+         self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+         self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+         self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
+
+     def send_obj(self, obj: Any, dst: int):
+         """Send an object to a destination rank."""
+         self.expire_data()
+         key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
+         self.store.set(key, pickle.dumps(obj))
+         self.send_dst_counter[dst] += 1
+         self.entries.append((key, time.time()))
+
+     def expire_data(self):
+         """Expire data that is older than `data_expiration_seconds` seconds."""
+         while self.entries:
+             # check the oldest entry
+             key, timestamp = self.entries[0]
+             if time.time() - timestamp > self.data_expiration_seconds:
+                 self.store.delete_key(key)
+                 self.entries.popleft()
+             else:
+                 break
+
+     def recv_obj(self, src: int) -> Any:
+         """Receive an object from a source rank."""
+         obj = pickle.loads(
+             self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
+         )
+         self.recv_src_counter[src] += 1
+         return obj
+
+     def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+         """Broadcast an object from a source rank to all other ranks.
+         It does not clean up after all ranks have received the object.
+         Use it for limited times, e.g., for initialization.
+         """
+         if self.rank == src:
+             self.expire_data()
+             key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
+             self.store.set(key, pickle.dumps(obj))
+             self.broadcast_send_counter += 1
+             self.entries.append((key, time.time()))
+             return obj
+         else:
+             key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
+             recv_obj = pickle.loads(self.store.get(key))
+             self.broadcast_recv_src_counter[src] += 1
+             return recv_obj
+
+     def all_gather_obj(self, obj: Any) -> list[Any]:
+         """All gather an object from all ranks."""
+         gathered_objs = []
+         for i in range(self.world_size):
+             if i == self.rank:
+                 gathered_objs.append(obj)
+                 self.broadcast_obj(obj, src=self.rank)
+             else:
+                 recv_obj = self.broadcast_obj(None, src=i)
+                 gathered_objs.append(recv_obj)
+         return gathered_objs
+
+     def barrier(self):
+         """A barrier to synchronize all ranks."""
+         for i in range(self.world_size):
+             if i == self.rank:
+                 self.broadcast_obj(None, src=self.rank)
+             else:
+                 self.broadcast_obj(None, src=i)
+
+     @staticmethod
+     def create(
+         host: str,
+         port: int,
+         rank: int,
+         world_size: int,
+         data_expiration_seconds: int = 3600,
+     ) -> "StatelessProcessGroup":
+         """A replacement for `torch.distributed.init_process_group` that does not
+         pollute the global state.
+
+         If we have process A and process B called `torch.distributed.init_process_group`
+         to form a group, and then we want to form another group with process A, B, C,
+         D, it is not possible in PyTorch, because process A and process B have already
+         formed a group, and process C and process D cannot join that group. This
+         function is a workaround for this issue.
+
+         `torch.distributed.init_process_group` is a global call, while this function
+         is a stateless call. It will return a `StatelessProcessGroup` object that can be
+         used for exchanging metadata. With this function, process A and process B
+         can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+         C, and D can call `StatelessProcessGroup.create` to form another group.
+         """  # noqa
+         store = TCPStore(
+             host_name=host,
+             port=port,
+             world_size=world_size,
+             is_master=(rank == 0),
+         )
+
+         return StatelessProcessGroup(
+             rank=rank,
+             world_size=world_size,
+             store=store,
+             data_expiration_seconds=data_expiration_seconds,
+         )
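The new get_pp_indices helper either splits layers evenly (with the remainder going to the last pipeline rank) or honors an explicit SGLANG_PP_LAYER_PARTITION list. A minimal sketch of both paths, using illustrative values (30 hidden layers, 4 pipeline ranks) that are not taken from the package:

import os

from sglang.srt.distributed.utils import get_pp_indices

# Even split: 30 // 4 = 7 layers per rank; the last rank absorbs the remainder.
# -> [(0, 7), (7, 14), (14, 21), (21, 30)]
print([get_pp_indices(30, rank, 4) for rank in range(4)])

# Explicit partition: entries must match pp_size and sum to num_hidden_layers,
# otherwise get_pp_indices raises ValueError.
os.environ["SGLANG_PP_LAYER_PARTITION"] = "10,10,5,5"
# -> [(0, 10), (10, 20), (20, 25), (25, 30)]
print([get_pp_indices(30, rank, 4) for rank in range(4)])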
sglang/srt/hf_transformers_utils.py
@@ -16,6 +16,7 @@
  import contextlib
  import os
  import warnings
+ from pathlib import Path
  from typing import Dict, Optional, Type, Union

  from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
      PreTrainedTokenizer,
      PreTrainedTokenizerFast,
  )
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

  try:
      from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,31 @@ def get_config(
      trust_remote_code: bool,
      revision: Optional[str] = None,
      model_override_args: Optional[dict] = None,
+     **kwargs,
  ):
+     is_gguf = check_gguf_file(model)
+     if is_gguf:
+         kwargs["gguf_file"] = model
+         model = Path(model).parent
+
      config = AutoConfig.from_pretrained(
-         model, trust_remote_code=trust_remote_code, revision=revision
+         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
      )
      if config.model_type in _CONFIG_REGISTRY:
          config_class = _CONFIG_REGISTRY[config.model_type]
          config = config_class.from_pretrained(model, revision=revision)
+         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+         setattr(config, "_name_or_path", model)
      if model_override_args:
          config.update(model_override_args)
+
+     # Special architecture mapping check for GGUF models
+     if is_gguf:
+         if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+             raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+         config.update({"architectures": [model_type]})
+
      return config


@@ -123,6 +141,11 @@ def get_tokenizer(
              raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
          kwargs["use_fast"] = False

+     is_gguf = check_gguf_file(tokenizer_name)
+     if is_gguf:
+         kwargs["gguf_file"] = tokenizer_name
+         tokenizer_name = Path(tokenizer_name).parent
+
      try:
          tokenizer = AutoTokenizer.from_pretrained(
              tokenizer_name,
@@ -195,3 +218,16 @@ def attach_additional_stop_token_ids(tokenizer):
          )
      else:
          tokenizer.additional_stop_token_ids = None
+
+
+ def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+     """Check if the file is a GGUF model."""
+     model = Path(model)
+     if not model.is_file():
+         return False
+     elif model.suffix == ".gguf":
+         return True
+
+     with open(model, "rb") as f:
+         header = f.read(4)
+     return header == b"GGUF"
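A hedged usage sketch of the new GGUF path: get_config() and get_tokenizer() now detect a GGUF checkpoint (by the .gguf suffix or the b"GGUF" magic bytes), pass the file to transformers via gguf_file=..., and use the file's parent directory as the model path. The checkpoint path below is hypothetical.

from sglang.srt.hf_transformers_utils import get_config, get_tokenizer

gguf_path = "/models/llama-3.1-8b-instruct-q4_k_m.gguf"  # hypothetical local file

# check_gguf_file() fires on the ".gguf" suffix, so the config is loaded with
# AutoConfig.from_pretrained(parent_dir, gguf_file=gguf_path, ...).
config = get_config(gguf_path, trust_remote_code=False)

# The tokenizer goes through the same rerouting before AutoTokenizer is called.
tokenizer = get_tokenizer(gguf_path)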
sglang/srt/layers/attention/__init__.py
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
          v: torch.Tensor,
          layer: RadixAttention,
          forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
      ):
          """Run forward on an attention layer."""
          if forward_batch.forward_mode.is_decode():
-             return self.forward_decode(q, k, v, layer, forward_batch)
+             return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
          else:
-             return self.forward_extend(q, k, v, layer, forward_batch)
+             return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

      def forward_decode(
          self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
          v: torch.Tensor,
          layer: RadixAttention,
          forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
      ):
          """Run a forward for decode."""
          raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
          v: torch.Tensor,
          layer: RadixAttention,
          forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
      ):
          """Run a forward for extend."""
          raise NotImplementedError()
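For context, a minimal sketch (not part of the diff) of how a backend honors the save_kv_cache flag that AttentionBackend.forward() now threads through to forward_extend/forward_decode. The subclass name and the elided kernel calls are hypothetical; the set_kv_buffer call mirrors the pattern in the backends below.

from sglang.srt.layers.attention import AttentionBackend


class SketchBackend(AttentionBackend):
    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Write the new K/V into the pool only when asked; callers pass
        # save_kv_cache=False when the entries are already (or never) cached.
        if k is not None and save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        raise NotImplementedError("run the extend kernel and return the output")

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if k is not None and save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        raise NotImplementedError("run the decode kernel and return the output")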
sglang/srt/layers/attention/double_sparsity_backend.py
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
          return 1

      def forward_extend(
-         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+         self,
+         q,
+         k,
+         v,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
      ):
          # TODO: reuse the buffer across layers
          if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
              .expand(k.shape[0], -1, -1),
          )

-         forward_batch.token_to_kv_pool.set_kv_buffer(
-             layer, forward_batch.out_cache_loc, k, v, k_label
-         )
+         if save_kv_cache:
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer, forward_batch.out_cache_loc, k, v, k_label
+             )

          (
              start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
          return o

      def forward_decode(
-         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+         self,
+         q,
+         k,
+         v,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
      ):
          # During torch.compile, there is a bug in rotary_emb that causes the
          # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
              .expand(k.shape[0], -1, -1),
          )

-         forward_batch.token_to_kv_pool.set_kv_buffer(
-             layer, forward_batch.out_cache_loc, k, v, k_label
-         )
+         if save_kv_cache:
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer, forward_batch.out_cache_loc, k, v, k_label
+             )

          # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
          # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@ import triton.language as tl
  from sglang.global_config import global_config
  from sglang.srt.layers.attention import AttentionBackend
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+ from sglang.srt.utils import (
+     get_bool_env_var,
+     is_flashinfer_available,
+     should_use_tensor_core,
+ )

  if TYPE_CHECKING:
      from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
          BatchPrefillWithRaggedKVCacheWrapper,
      )
      from flashinfer.cascade import merge_state
-     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


  class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
      def __init__(self, model_runner: ModelRunner):
          super().__init__()

-         # Parse constants
-         if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-             self.decode_use_tensor_cores = get_bool_env_var(
-                 "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-             )
-         else:
-             if not _grouped_size_compiled_for_decode_kernels(
-                 model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                 model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-             ):
-                 self.decode_use_tensor_cores = True
-             else:
-                 self.decode_use_tensor_cores = False
+         self.decode_use_tensor_cores = should_use_tensor_core(
+             kv_cache_dtype=model_runner.kv_cache_dtype,
+             num_attention_heads=model_runner.model_config.num_attention_heads
+             // model_runner.tp_size,
+             num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                 model_runner.tp_size
+             ),
+         )

          self.max_context_len = model_runner.model_config.context_len

@@ -223,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
          return 0

      def forward_extend(
-         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+         self,
+         q,
+         k,
+         v,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
      ):
          prefill_wrapper_paged = self.prefill_wrappers_paged[
              self._get_wrapper_idx(layer)
@@ -239,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
          if not use_ragged:
              if k is not None:
                  assert v is not None
-                 forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                 if save_kv_cache:
+                     forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

              o = prefill_wrapper_paged.forward(
                  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -272,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):

              o, _ = merge_state(o1, s1, o2, s2)

-             forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+             if save_kv_cache:
+                 forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

          return o.view(-1, layer.tp_q_head_num * layer.head_dim)

      def forward_decode(
-         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+         self,
+         q,
+         k,
+         v,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
      ):
          decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
          cache_loc = (
@@ -288,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):

          if k is not None:
              assert v is not None
-             forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+             if save_kv_cache:
+                 forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

          o = decode_wrapper.forward(
              q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
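The tensor-core decision that used to live inline in __init__ is now delegated to should_use_tensor_core() in sglang/srt/utils.py. Below is a hedged sketch of that heuristic: only the keyword arguments (kv_cache_dtype, num_attention_heads, num_kv_heads) are taken from the call in the diff; the override handling and the thresholds shown are assumptions and may differ from the shipped implementation.

import os
from typing import Optional

import torch


def should_use_tensor_core_sketch(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    # The SGLANG_FLASHINFER_USE_TENSOR_CORE override from the old code path is
    # assumed to keep working as an explicit opt-in/opt-out.
    env_override: Optional[str] = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
    if env_override is not None:
        return env_override.lower() == "true"

    # FP8 KV caches and large GQA group sizes are where FlashInfer's tensor-core
    # decode kernels usually help (the group-size threshold of 4 is an assumption).
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    return (num_attention_heads // num_kv_heads) >= 4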