sglang-0.1.21-py3-none-any.whl → sglang-0.1.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py ADDED
@@ -0,0 +1,283 @@
+ import json
+ from typing import List, Optional
+
+ import numpy as np
+
+ from sglang.global_config import global_config
+ from sglang.lang.backend.base_backend import BaseBackend
+ from sglang.lang.chat_template import get_chat_template_by_model_path
+ from sglang.lang.interpreter import StreamExecutor
+ from sglang.lang.ir import SglSamplingParams
+ from sglang.utils import http_request
+
+
+ class RuntimeEndpoint(BaseBackend):
+     def __init__(
+         self,
+         base_url: str,
+         auth_token: Optional[str] = None,
+         api_key: Optional[str] = None,
+         verify: Optional[str] = None,
+     ):
+         super().__init__()
+         self.support_concate_and_append = True
+
+         self.base_url = base_url
+         self.auth_token = auth_token
+         self.api_key = api_key
+         self.verify = verify
+
+         res = http_request(
+             self.base_url + "/get_model_info",
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+         self.model_info = res.json()
+
+         self.chat_template = get_chat_template_by_model_path(
+             self.model_info["model_path"]
+         )
+
+     def get_model_name(self):
+         return self.model_info["model_path"]
+
+     def flush_cache(self):
+         res = http_request(
+             self.base_url + "/flush_cache",
+             auth_token=self.auth_token,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+     def get_server_args(self):
+         res = http_request(
+             self.base_url + "/get_server_args",
+             auth_token=self.auth_token,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+         return res.json()
+
+     def get_chat_template(self):
+         return self.chat_template
+
+     def cache_prefix(self, prefix_str: str):
+         res = http_request(
+             self.base_url + "/generate",
+             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+     def commit_lazy_operations(self, s: StreamExecutor):
+         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+         self._add_images(s, data)
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+     def fill_image(self, s: StreamExecutor):
+         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+         self._add_images(s, data)
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+     def generate(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if sampling_params.dtype is None:
+             data = {
+                 "text": s.text_,
+                 "sampling_params": {
+                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                     "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                     **sampling_params.to_srt_kwargs(),
+                 },
+             }
+         elif sampling_params.dtype in [int, "int"]:
+             data = {
+                 "text": s.text_,
+                 "sampling_params": {
+                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                     "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                     "dtype": "int",
+                     **sampling_params.to_srt_kwargs(),
+                 },
+             }
+         else:
+             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+         for item in [
+             "return_logprob",
+             "logprob_start_len",
+             "top_logprobs_num",
+             "return_text_in_logprobs",
+         ]:
+             value = getattr(sampling_params, item, None)
+             if value is not None:
+                 data[item] = value
+
+         self._add_images(s, data)
+
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+         obj = res.json()
+         comp = obj["text"]
+         return comp, obj["meta_info"]
+
+     def generate_stream(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if sampling_params.dtype is None:
+             data = {
+                 "text": s.text_,
+                 "sampling_params": {
+                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                     "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                     **sampling_params.to_srt_kwargs(),
+                 },
+             }
+         elif sampling_params.dtype in [int, "int"]:
+             data = {
+                 "text": s.text_,
+                 "sampling_params": {
+                     "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                     "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+                     "dtype": "int",
+                     **sampling_params.to_srt_kwargs(),
+                 },
+             }
+         else:
+             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+         for item in [
+             "return_logprob",
+             "logprob_start_len",
+             "top_logprobs_num",
+             "return_text_in_logprobs",
+         ]:
+             value = getattr(sampling_params, item, None)
+             if value is not None:
+                 data[item] = value
+
+         data["stream"] = True
+         self._add_images(s, data)
+
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             stream=True,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+         pos = 0
+
+         for chunk in res.iter_lines(decode_unicode=False):
+             chunk = chunk.decode("utf-8")
+             if chunk and chunk.startswith("data:"):
+                 if chunk == "data: [DONE]":
+                     break
+                 data = json.loads(chunk[5:].strip("\n"))
+                 chunk_text = data["text"][pos:]
+                 meta_info = data["meta_info"]
+                 pos += len(chunk_text)
+                 yield chunk_text, meta_info
+
+     def select(
+         self,
+         s: StreamExecutor,
+         choices: List[str],
+         temperature: float,
+     ):
+         assert temperature <= 1e-5
+
+         # Cache common prefix
+         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
+         self._add_images(s, data)
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+         prompt_len = res.json()["meta_info"]["prompt_tokens"]
+
+         # Compute logprob
+         data = {
+             "text": [s.text_ + c for c in choices],
+             "sampling_params": {"max_new_tokens": 0},
+             "return_logprob": True,
+             "logprob_start_len": max(prompt_len - 2, 0),
+         }
+         self._add_images(s, data)
+         res = http_request(
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+         obj = res.json()
+         normalized_prompt_logprobs = [
+             r["meta_info"]["normalized_prompt_logprob"] for r in obj
+         ]
+         decision = choices[np.argmax(normalized_prompt_logprobs)]
+         prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
+         decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+
+         return (
+             decision,
+             normalized_prompt_logprobs,
+             prefill_token_logprobs,
+             decode_token_logprobs,
+         )
+
+     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
+         res = http_request(
+             self.base_url + "/concate_and_append_request",
+             json={"src_rids": src_rids, "dst_rid": dst_rid},
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
+         )
+         self._assert_success(res)
+
+     def _add_images(self, s: StreamExecutor, data):
+         if s.images_:
+             assert len(s.images_) == 1, "Only support one image."
+             data["image_data"] = s.images_[0][1]
+
+     def _assert_success(self, res):
+         if res.status_code != 200:
+             raise RuntimeError(res.json())
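Note: the RuntimeEndpoint backend above drives a running SRT server over plain HTTP. A minimal usage sketch follows; the URL, program, and prompt are illustrative only and assume a server is already listening at the given address:

import sglang as sgl

# Assumed server address; start the server separately before running this.
backend = sgl.RuntimeEndpoint("http://localhost:30000")
sgl.set_default_backend(backend)

@sgl.function
def qa(s, question):
    # Build the prompt and request up to 64 new tokens from the endpoint.
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

state = qa.run(question="What does the RuntimeEndpoint backend do?")
print(state["answer"])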
sglang/lang/backend/vertexai.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ import warnings
+ from typing import Optional
+
+ from sglang.lang.backend.base_backend import BaseBackend
+ from sglang.lang.chat_template import get_chat_template
+ from sglang.lang.interpreter import StreamExecutor
+ from sglang.lang.ir import SglSamplingParams
+
+ try:
+     import vertexai
+     from vertexai.preview.generative_models import (
+         GenerationConfig,
+         GenerativeModel,
+         Image,
+     )
+ except ImportError as e:
+     GenerativeModel = e
+
+
+ class VertexAI(BaseBackend):
+     def __init__(self, model_name, safety_settings=None):
+         super().__init__()
+
+         if isinstance(GenerativeModel, Exception):
+             raise GenerativeModel
+
+         project_id = os.environ["GCP_PROJECT_ID"]
+         location = os.environ.get("GCP_LOCATION")
+         vertexai.init(project=project_id, location=location)
+
+         self.model_name = model_name
+         self.chat_template = get_chat_template("default")
+         self.safety_settings = safety_settings
+
+     def get_chat_template(self):
+         return self.chat_template
+
+     def generate(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         ret = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+             safety_settings=self.safety_settings,
+         )
+
+         comp = ret.text
+
+         return comp, {}
+
+     def generate_stream(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         generator = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             stream=True,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+             safety_settings=self.safety_settings,
+         )
+         for ret in generator:
+             yield ret.text, {}
+
+     def text_to_vertexai_input(self, text, images):
+         input = []
+         # split with image token
+         text_segs = text.split(self.chat_template.image_token)
+         for image_path, image_base64_data in images:
+             text_seg = text_segs.pop(0)
+             if text_seg != "":
+                 input.append(text_seg)
+             input.append(Image.from_bytes(image_base64_data))
+         text_seg = text_segs.pop(0)
+         if text_seg != "":
+             input.append(text_seg)
+         return input
+
+     def messages_to_vertexai_input(self, messages):
+         vertexai_message = []
+         # from openai message format to vertexai message format
+         for msg in messages:
+             if isinstance(msg["content"], str):
+                 text = msg["content"]
+             else:
+                 text = msg["content"][0]["text"]
+
+             if msg["role"] == "system":
+                 warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                 vertexai_message.append(
+                     {
+                         "role": "user",
+                         "parts": [{"text": "System prompt: " + text}],
+                     }
+                 )
+                 vertexai_message.append(
+                     {
+                         "role": "model",
+                         "parts": [{"text": "Understood."}],
+                     }
+                 )
+                 continue
+             if msg["role"] == "user":
+                 vertexai_msg = {
+                     "role": "user",
+                     "parts": [{"text": text}],
+                 }
+             elif msg["role"] == "assistant":
+                 vertexai_msg = {
+                     "role": "model",
+                     "parts": [{"text": text}],
+                 }
+
+             # images
+             if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                 for image in msg["content"][1:]:
+                     assert image["type"] == "image_url"
+                     vertexai_msg["parts"].append(
+                         {
+                             "inline_data": {
+                                 "data": image["image_url"]["url"].split(",")[1],
+                                 "mime_type": "image/jpeg",
+                             }
+                         }
+                     )
+
+             vertexai_message.append(vertexai_msg)
+         return vertexai_message
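Note: the relocated VertexAI backend reads GCP_PROJECT_ID (and optionally GCP_LOCATION) from the environment and imports the vertexai SDK lazily. A minimal selection sketch; the model name and prompt are placeholders, not part of this diff:

import sglang as sgl

# Assumes GCP_PROJECT_ID is set and Vertex AI credentials are configured.
sgl.set_default_backend(sgl.VertexAI("gemini-pro"))

@sgl.function
def greet(s):
    # Ask the Vertex AI model for a short completion.
    s += "Say hello in one short sentence. " + sgl.gen("out", max_tokens=32)

print(greet.run()["out"])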
sglang/lang/tracer.py CHANGED
@@ -3,8 +3,8 @@
  import uuid
  from typing import Any, Callable, Dict, List, Optional, Union

- from sglang.backend.base_backend import BaseBackend
  from sglang.global_config import global_config
+ from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.interpreter import ProgramState, ProgramStateGroup
  from sglang.lang.ir import (
      SglArgument,
sglang/launch_server.py CHANGED
@@ -11,4 +11,4 @@ if __name__ == "__main__":
      args = parser.parse_args()
      server_args = ServerArgs.from_cli_args(args)

-     launch_server(server_args, None)
+     launch_server(server_args)
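Note: with this change the entry point no longer threads a pipe through launch_server. A hedged sketch of a programmatic launch under the new call shape; the model path and port are placeholders, and the keyword arguments are assumptions about the ServerArgs dataclass rather than part of this diff:

from sglang.srt.server import ServerArgs, launch_server

# Placeholder arguments; adjust model_path and port for your deployment.
server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", port=30000)
launch_server(server_args)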
sglang/launch_server_llavavid.py CHANGED
@@ -1,7 +1,6 @@
  """Launch the inference server for Llava-video model."""

  import argparse
- import multiprocessing as mp

  from sglang.srt.server import ServerArgs, launch_server

@@ -27,6 +26,4 @@ if __name__ == "__main__":

      server_args = ServerArgs.from_cli_args(args)

-     pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-
-     launch_server(server_args, pipe_writer, model_overide_args)
+     launch_server(server_args, model_overide_args, None)
sglang/srt/conversation.py CHANGED
@@ -6,7 +6,7 @@ import dataclasses
  from enum import IntEnum, auto
  from typing import Dict, List, Optional, Tuple, Union

- from sglang.srt.openai_protocol import ChatCompletionRequest
+ from sglang.srt.openai_api.protocol import ChatCompletionRequest


  class SeparatorStyle(IntEnum):
sglang/srt/layers/context_flashattention_nopad.py CHANGED
@@ -4,8 +4,6 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.utils import wrap_kernel_launcher
-
  CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -119,9 +117,6 @@ def _fwd_kernel(
      tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


- cached_kernel = None
-
-
  def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
      if CUDA_CAPABILITY[0] >= 8:
          BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
      grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
      num_warps = 4 if Lk <= 64 else 8

-     global cached_kernel
-     if cached_kernel:
-         cached_kernel(
-             grid,
-             num_warps,
-             q,
-             k,
-             v,
-             sm_scale,
-             b_start_loc,
-             b_seq_len,
-             o,
-             q.stride(0),
-             q.stride(1),
-             k.stride(0),
-             k.stride(1),
-             v.stride(0),
-             v.stride(1),
-             o.stride(0),
-             o.stride(1),
-         )
-         return
-
      _fwd_kernel[grid](
          q,
          k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
          num_warps=num_warps,
          num_stages=1,
      )
-     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
sglang/srt/layers/extend_attention.py CHANGED
@@ -3,7 +3,6 @@ import triton
  import triton.language as tl

  from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
- from sglang.srt.utils import wrap_kernel_launcher

  CUDA_CAPABILITY = torch.cuda.get_device_capability()

@@ -172,9 +171,6 @@ def _fwd_kernel(
      tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


- cached_kernel = None
-
-
  def extend_attention_fwd(
      q_extend,
      k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
      num_warps = 4 if Lk <= 64 else 8
      num_stages = 1

-     global cached_kernel
-     if cached_kernel:
-         cached_kernel(
-             grid,
-             num_warps,
-             q_extend,
-             k_extend,
-             v_extend,
-             o_extend,
-             k_buffer,
-             v_buffer,
-             req_to_tokens,
-             b_req_idx,
-             b_seq_len,
-             b_start_loc_extend,
-             b_seq_len_extend,
-             sm_scale,
-             kv_group_num,
-             q_extend.stride(0),
-             q_extend.stride(1),
-             k_extend.stride(0),
-             k_extend.stride(1),
-             v_extend.stride(0),
-             v_extend.stride(1),
-             o_extend.stride(0),
-             o_extend.stride(1),
-             k_buffer.stride(0),
-             k_buffer.stride(1),
-             v_buffer.stride(0),
-             v_buffer.stride(1),
-             req_to_tokens.stride(0),
-         )
-         return
-
      _fwd_kernel[grid](
          q_extend,
          k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
          num_stages=num_stages,
          logit_cap=logit_cap,
      )
-     cached_kernel = wrap_kernel_launcher(_fwd_kernel)


  def redundant_attention(