sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0

sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,21 +1,24 @@
 import json
 from typing import List, Optional

-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request


 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +26,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True

         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify

         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -46,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -54,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -67,7 +68,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +79,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +90,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +137,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +190,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -216,21 +212,14 @@
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5

         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            auth_token=self.auth_token,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]

         # Compute logprob
         data = {
@@ -239,40 +228,57 @@
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            auth_token=self.auth_token,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

-        return (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)

+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
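
The refactored select() above now delegates scoring to a pluggable ChoicesSamplingMethod and returns a single ChoicesDecision instead of a 4-tuple. A minimal sketch of that contract, exercising the default method from the choices module added below with made-up logprob data (token tuples are shaped (logprob, token_id), matching what this code reads from the /generate response):

```python
from sglang.lang.choices import token_length_normalized

decision = token_length_normalized(
    choices=["yes", "no"],
    normalized_prompt_logprobs=[-0.4, -1.9],  # toy values, one per choice
    input_token_logprobs=[[(-0.4, 9820)], [(-1.9, 2201)]],
    output_token_logprobs=[[], []],
)
print(decision.decision)  # "yes" -- argmax of the normalized prompt logprobs
print(decision.meta_info["normalized_prompt_logprobs"])
```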
sglang/lang/choices.py ADDED
@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+
+@dataclass
+class ChoicesDecision:
+    decision: str
+    meta_info: Optional[Dict[str, Any]] = None
+
+
+class ChoicesSamplingMethod(ABC):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return False
+
+    @abstractmethod
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision: ...
+
+
+class TokenLengthNormalized(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest token length normalized prompt logprob."""
+        best_choice = choices[np.argmax(normalized_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+
+token_length_normalized = TokenLengthNormalized()
+
+
+class GreedyTokenSelection(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option based on greedy logprob selection. For overlapping options
+        where one option is a subset of a longer option, extend the shorter option using
+        its average logprob for comparison against the longer option."""
+
+        num_options = len(choices)
+        max_tokens = max(len(option) for option in input_token_logprobs)
+        logprob_matrix = self._build_logprob_matrix(
+            input_token_logprobs, max_tokens, num_options
+        )
+        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
+
+        best_choice = choices[remaining[0]]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "greedy_logprob_matrix": logprob_matrix.tolist(),
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
+        logprob_matrix = np.zeros((num_options, max_tokens))
+        for i, option in enumerate(input_token_logprobs):
+            actual_logprobs = [token[0] for token in option]
+            avg_logprob = np.mean(actual_logprobs)
+            logprob_matrix[i, : len(option)] = actual_logprobs
+            if len(option) < max_tokens:
+                logprob_matrix[i, len(option) :] = avg_logprob
+        return logprob_matrix
+
+    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
+        remaining = np.arange(num_options)
+        for j in range(max_tokens):
+            max_logprob = np.max(logprob_matrix[remaining, j])
+            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
+            if len(remaining) == 1:
+                break
+        return remaining
+
+
+greedy_token_selection = GreedyTokenSelection()
+
+
+class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return True
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest average token logprob once normalized by
+        the unconditional token logprobs.
+
+        The first unconditional token logprob is assumed to be None. If so, it is
+        replaced with 0 for the purposes of normalization."""
+
+        if unconditional_token_logprobs is None:
+            raise ValueError(
+                "Unconditional token logprobs are required for this method."
+            )
+
+        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
+            input_token_logprobs, unconditional_token_logprobs
+        )
+
+        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "unconditional_token_logprobs": unconditional_token_logprobs,
+            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
+        normalized_unconditional_prompt_logprobs = []
+        for inputs, unconditionals in zip(
+            input_token_logprobs, unconditional_token_logprobs
+        ):
+            inputs_logprobs = np.array([token[0] for token in inputs])
+            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
+            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
+            normalized_unconditional_prompt_logprobs.append(
+                float(np.mean(inputs_logprobs - unconditionals_logprobs))
+            )
+        return normalized_unconditional_prompt_logprobs
+
+
+unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
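
The module exposes each strategy as a ready-made singleton (token_length_normalized, greedy_token_selection, unconditional_likelihood_normalized). A usage sketch, assuming the sgl.select signature this release extends with a choices_method parameter; the function and prompt below are illustrative:

```python
import sglang as sgl
from sglang.lang.choices import greedy_token_selection

@sgl.function
def classify(s, review):
    s += "Review: " + review + "\nSentiment: "
    # greedy_token_selection resolves overlapping choices token by token,
    # rather than comparing length-normalized prompt logprobs.
    s += sgl.select(
        "sentiment",
        choices=["positive", "negative"],
        choices_method=greedy_token_selection,
    )
```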

sglang/lang/interpreter.py CHANGED
@@ -538,24 +538,17 @@ class StreamExecutor:
             self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
-        ) = self.backend.select(self, expr.choices, expr.temperature)
+        choices_decision = self.backend.select(
+            self, expr.choices, expr.temperature, expr.choices_method
+        )
         if expr.name is not None:
             name = expr.name
-            self.variables[name] = decision
-            self.meta_info[name] = {
-                "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "input_token_logprobs": input_token_logprobs,
-                "output_token_logprobs": output_token_logprobs,
-            }
+            self.variables[name] = choices_decision.decision
+            self.meta_info[name] = choices_decision.meta_info
             self.variable_event[name].set()
             if self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.text_ += decision
+        self.text_ += choices_decision.decision

     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
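
Since _execute_select now stores ChoicesDecision.meta_info verbatim, the per-choice scores surface in the program state's meta info. A short follow-on sketch, assuming the classify program above and state.get_meta_info, which sglang's program state exposes:

```python
state = classify.run(review="Great food, slow service.")
print(state["sentiment"])                  # the selected choice
info = state.get_meta_info("sentiment")    # ChoicesDecision.meta_info
print(info["normalized_prompt_logprobs"])  # one score per choice
```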
sglang/lang/ir.py CHANGED
@@ -6,6 +6,7 @@ import warnings
 from typing import List, Optional, Union

 from sglang.global_config import global_config
+from sglang.lang.choices import ChoicesSamplingMethod

 REGEX_INT = r"[-+]?[0-9]+"
 REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):


 class SglSelect(SglExpr):
-    def __init__(self, name: str, choices: List[str], temperature: float):
+
+    def __init__(
+        self,
+        name: str,
+        choices: List[str],
+        temperature: float,
+        choices_method: ChoicesSamplingMethod,
+    ):
         super().__init__()
         self.name = name
         self.choices = choices
         self.temperature = temperature
+        self.choices_method = choices_method

     def __repr__(self):
-        return f"Select({self.name}, choices={self.choices})"
+        return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"


 class SglFork(SglExpr):

sglang/srt/hf_transformers_utils.py CHANGED
@@ -19,7 +19,7 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
+from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
             Literal["all"], AbstractSet[str]
         ] = set(),  # noqa: B006
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-    ) -> list[int]:
+    ) -> List[int]:
         if isinstance(allowed_special, set):
            allowed_special |= self._default_allowed_special
         return tiktoken.Encoding.encode(
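
Our reading of the annotation change: `list[int]` as a runtime-evaluated return annotation requires Python >= 3.9 (PEP 585), so switching to typing.List keeps the tokenizer importable on 3.8:

```python
from typing import List

def encode(text: str) -> List[int]:  # evaluates on 3.8; `list[int]` raises TypeError there
    ...
```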

sglang/srt/layers/extend_attention.py CHANGED
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)

-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -110,6 +124,18 @@

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -125,7 +151,7 @@
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale

         if logit_cap > 0:
@@ -169,7 +210,7 @@
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
@@ -217,8 +258,17 @@
         o_extend.shape[-1],
     )

-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
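
The new Lq == 576 branch serves MLA-style attention (deepseek_v2.py gains +199 lines in this same release), where the q/k head dim carries an extra positional-encoding slice on top of a 512-dim value head. A sketch mirroring the split the launcher now performs; the constants come straight from the diff, while reading 512/64 as "latent + rotary" is our interpretation, not stated in the code:

```python
def split_head_dims(Lq: int, Lv: int):
    """Mirror the BLOCK_DMODEL / BLOCK_DPE / BLOCK_DV selection above."""
    assert Lq in {16, 32, 64, 128, 256, 576}
    assert Lv in {16, 32, 64, 128, 256, 512}
    if Lq == 576:
        # 576-dim q/k splits into a 512-dim "model" part plus a 64-dim
        # positional-encoding (DPE) part that gets its own tl.dot in the kernel.
        return 512, 64, Lv
    return Lq, 0, Lv

print(split_head_dims(576, 512))  # (512, 64, 512)
print(split_head_dims(128, 128))  # (128, 0, 128)
```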

sglang/srt/layers/logits_processor.py CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


 @dataclasses.dataclass

sglang/srt/layers/radix_attention.py CHANGED
@@ -22,11 +22,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -38,16 +35,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id

-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +60,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +89,17 @@ class RadixAttention(nn.Module):
         return o

     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +170,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
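
RadixAttention now tracks qk_head_dim and v_head_dim separately and only takes the flashinfer path when they match. A hedged construction sketch: the keyword arguments mirror this diff, while the concrete sizes (576 = 512 + 64 vs. 512) are DeepSeek-V2-like and purely illustrative:

```python
from sglang.srt.layers.radix_attention import RadixAttention

# head_dim becomes qk_head_dim; a differing v_head_dim forces the
# Triton extend/decode kernels selected in __init__ above.
attn = RadixAttention(
    num_heads=16,
    head_dim=576,
    scaling=576 ** -0.5,
    num_kv_heads=1,
    layer_id=0,
    v_head_dim=512,
)
```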