sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
  import queue
  import threading
  import uuid
+ import warnings
  from concurrent.futures import ThreadPoolExecutor
  from contextlib import contextmanager
  from typing import Any, Callable, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ from sglang.lang.ir import (
      SglVarScopeEnd,
      SglVideo,
  )
- from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
+ from sglang.utils import (
+     encode_image_base64,
+     encode_video_base64,
+     get_exception_traceback,
+ )


  def run_internal(state, program, func_args, func_kwargs, sync):
@@ -61,7 +66,7 @@ def run_program(
          default_sampling_para,
          chat_template=None,
          stream=stream,
-         api_num_spec_tokens=program.api_num_spec_tokens,
+         num_api_spec_tokens=program.num_api_spec_tokens,
      )
      state = ProgramState(stream_executor)

@@ -173,7 +178,7 @@ class StreamExecutor:
          default_sampling_para,
          chat_template,
          stream,
-         api_num_spec_tokens=None,
+         num_api_spec_tokens=None,
          use_thread=True,
      ):
          self.sid = uuid.uuid4().hex
@@ -181,20 +186,16 @@ class StreamExecutor:
          self.arguments: Dict[str, Any] = arguments
          self.default_sampling_para = default_sampling_para
          self.stream = stream
-         self.api_num_spec_tokens = api_num_spec_tokens

          self.variables = {}  # Dict[name: str -> value: str]
          self.variable_event = {}  # Dict[name: str -> event: threading.Event]
          self.meta_info = {}  # Dict[name: str -> info: str]
          self.is_finished = False
-         self.error = None
+         self.error_ = None

          # For completion
          self.text_ = ""  # The full text

-         # For speculative execution
-         self.speculated_text = ""
-

          # For chat
          self.messages_ = []  # The messages in the OpenAI API format
@@ -208,6 +209,10 @@ class StreamExecutor:
          # For fork/join
          self.fork_start_text_pos = None

+         # For speculative execution
+         self.num_api_spec_tokens = num_api_spec_tokens
+         self.speculated_text = ""
+
          # Worker thread
          self.use_thread = use_thread
          if self.use_thread:
@@ -286,6 +291,8 @@ class StreamExecutor:
              exes[i].fork_start_text_pos = len(self.text_)
              exes[i].images_ = list(self.images_)

+         # TODO(ying): handle API speculative execution
+
          return exes

      def text(self):
@@ -296,6 +303,10 @@ class StreamExecutor:
          self.sync()
          return self.messages_

+     def error(self):
+         self.sync()
+         return self.error_
+
      def end(self):
          if self.use_thread:
              if self.worker.is_alive():
@@ -314,7 +325,7 @@ class StreamExecutor:
              try:
                  self._execute(expr)
              except Exception as e:
-                 # print(f"Error in stream_executor: {get_exception_traceback()}")
+                 warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                  error = e
                  break
              self.queue.task_done()
@@ -334,7 +345,7 @@ class StreamExecutor:
          if self.stream_var_event:
              for name in self.stream_var_event:
                  self.stream_var_event[name].set()
-         self.error = error
+         self.error_ = error

          if self.stream_text_event:
              self.stream_text_event.set()
@@ -383,12 +394,23 @@ class StreamExecutor:
          else:
              raise ValueError(f"Unknown type: {type(other)}")

-     def _execute_fill(self, value: str):
+     def _execute_fill(self, value: str, prefix=False):
          value = str(value)
+
+         if (
+             self.cur_role == "assistant"
+             and self.num_api_spec_tokens is not None
+             and self.backend.is_chat_model
+             and not prefix
+         ):
+             self.backend.spec_fill(value)
+             return
+
          if self.speculated_text.startswith(value):
              self.speculated_text = self.speculated_text[len(value) :]
          else:
              self.speculated_text = ""
+
          self.text_ += value

      def _execute_image(self, expr: SglImage):
@@ -413,65 +435,80 @@ class StreamExecutor:
          # if global_config.eager_fill_image:
          #     self.backend.fill_image(self)

+     def _spec_gen(self, sampling_params):
+         stop = sampling_params.stop
+         max_new_tokens = sampling_params.max_new_tokens
+         meta_info = {}
+
+         def regen():
+             nonlocal meta_info
+
+             sampling_params.max_new_tokens = max(
+                 sampling_params.max_new_tokens, self.num_api_spec_tokens
+             )
+             sampling_params.stop = None
+             self.speculated_text, meta_info = self.backend.generate(
+                 self, sampling_params=sampling_params
+             )
+
+         def find_stop():
+             if isinstance(stop, str):
+                 return self.speculated_text.find(stop)
+             elif isinstance(stop, (tuple, list)):
+                 pos = -1
+                 for stop_str in stop:
+                     stop_pos = self.speculated_text.find(stop_str)
+                     if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                         pos = stop_pos
+                 return pos
+             else:
+                 raise Exception("Wrong type of stop in sampling parameters.")
+
+         if stop is None:
+             if len(self.speculated_text) < max_new_tokens:
+                 regen()
+             comp = self.speculated_text[:max_new_tokens]
+             self.speculated_text = self.speculated_text[max_new_tokens:]
+         elif isinstance(stop, (str, list, tuple)):
+             if self.speculated_text == "":
+                 regen()
+             stop_pos = find_stop()
+             if stop_pos == -1:
+                 stop_pos = min(
+                     sampling_params.max_new_tokens,
+                     len(self.speculated_text),
+                 )
+             comp = self.speculated_text[:stop_pos]
+             self.speculated_text = self.speculated_text[stop_pos:]
+         else:
+             raise ValueError("Wrong type of stop in sampling parameters.")
+
+         return comp, meta_info
+
      def _execute_gen(self, expr: SglGen):
          sampling_params = self._resolve_sampling_params(expr.sampling_params)
          name = expr.name

          if not self.stream:
-             if self.api_num_spec_tokens is not None:
-                 stop = sampling_params.stop
-                 max_new_tokens = sampling_params.max_new_tokens
-                 meta_info = {}
-
-                 def regen():
-                     sampling_params.max_new_tokens = max(
-                         sampling_params.max_new_tokens, self.api_num_spec_tokens
-                     )
-                     sampling_params.stop = None
-                     self.speculated_text, meta_info = self.backend.generate(
-                         self, sampling_params=sampling_params
-                     )
-
-                 def find_stop():
-                     if isinstance(stop, str):
-                         return self.speculated_text.find(stop), len(stop)
-                     elif isinstance(stop, (tuple, list)):
-                         pos = -1
-                         stop_len = 0
-                         for stop_str in stop:
-                             stop_pos = self.speculated_text.find(stop_str)
-                             if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                 pos = stop_pos
-                                 stop_len = len(stop_str)
-                         return pos, stop_len
-                     else:
-                         raise Exception("Wrong type of stop in sampling parameters.")
-
-                 if stop is None:
-                     if len(self.speculated_text) < max_new_tokens:
-                         regen()
-                     comp = self.speculated_text[:max_new_tokens]
-                     self.speculated_text = self.speculated_text[max_new_tokens:]
-                 elif isinstance(stop, (str, list, tuple)):
-                     if self.speculated_text == "":
-                         regen()
-                     stop_pos, stop_len = find_stop()
-                     if stop_pos == -1:
-                         stop_pos, stop_len = (
-                             min(
-                                 sampling_params.max_new_tokens,
-                                 len(self.speculated_text),
-                             ),
-                             0,
-                         )
-                     comp = self.speculated_text[:stop_pos]
-                     self.speculated_text = self.speculated_text[stop_pos:]
-                 else:
-                     raise ValueError("Wrong type of stop in sampling parameters.")
-             else:
+             if self.num_api_spec_tokens is None:
                  comp, meta_info = self.backend.generate(
-                     self, sampling_params=sampling_params
+                     self,
+                     sampling_params=sampling_params,
                  )
+             else:
+                 if self.backend.is_chat_model:
+                     # Speculative execution on models with only chat interface.
+                     # Store the calls into a temporary list.
+                     # They will be lazily executed later.
+                     comp, meta_info = self.backend.generate(
+                         self,
+                         sampling_params=sampling_params,
+                         spec_var_name=name,
+                     )
+                     return
+
+                 else:  # Speculative execution on models with completion interface
+                     comp, meta_info = self._spec_gen(sampling_params)

          self.text_ += comp

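The new `_spec_gen` helper above makes one over-long API call (at least `num_api_spec_tokens` tokens, with `stop` disabled) and then serves later `gen()` calls by slicing the cached `speculated_text` at the first stop string. A standalone toy illustration of that slicing, a sketch only and not sglang code (character counts stand in for token counts):

# Toy illustration of the speculative cache behind _spec_gen (not sglang code).
# One long completion is cached up front; each later gen() takes text up to the
# first stop string, or up to max_new_tokens "tokens" (characters here) if none is found.
cached = "Paris\nBerlin\nMadrid\n"

def take_until(stop: str, max_new_tokens: int) -> str:
    global cached
    pos = cached.find(stop)
    if pos == -1:
        pos = min(max_new_tokens, len(cached))
    piece, cached = cached[:pos], cached[pos:]
    return piece

print(take_until("\n", 16))  # "Paris"; the stop itself stays cached for the next fill,
                             # mirroring how _spec_gen leaves it in speculated_text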
@@ -479,6 +516,9 @@
          self.meta_info[name] = meta_info
          self.variable_event[name].set()
      else:
+         assert (
+             self.num_api_spec_tokens is None
+         ), "stream is not supported with api speculative execution"
          generator = self.backend.generate_stream(
              self, sampling_params=sampling_params
          )
@@ -534,10 +574,19 @@

          prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-         self._execute_fill(prefix)
+         self._execute_fill(prefix, prefix=True)
          self.cur_role_begin_pos = len(self.text_)

      def _execute_role_end(self, expr: SglRoleEnd):
+         if (
+             self.cur_role == "assistant"
+             and self.num_api_spec_tokens is not None
+             and self.backend.is_chat_model
+         ):
+             # Execute the stored lazy generation calls
+             self.backend.role_end_generate(self)
+         self.cur_role = None
+
          new_text = self.text_[self.cur_role_begin_pos :].lstrip()

          _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -564,8 +613,6 @@
          # OpenAI chat API format
          self.messages_.append({"role": expr.role, "content": new_text})

-         self.cur_role = None
-

      def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
          self.variables[expr.name] = int(len(self.text_))
@@ -709,7 +756,7 @@ class ProgramState:
          return self.stream_executor.sync()

      def error(self):
-         return self.stream_executor.error
+         return self.stream_executor.error()

      def text_iter(self, var_name: Optional[str] = None):
          if self.stream_executor.stream:
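Taken together, the rename threads `num_api_spec_tokens` from `SglFunction` through `run_program` into `StreamExecutor`, and `ProgramState.error()` now syncs the executor before reporting. A minimal usage sketch, assuming `sgl.function` forwards the renamed keyword (consistent with the small `sglang/api.py` +3 -3 change listed above); the backend, question, and token counts are illustrative:

# Sketch only: enabling API speculative execution after the rename in this release.
# Assumes sgl.function(num_api_spec_tokens=...) reaches SglFunction/StreamExecutor as above.
import sglang as sgl

sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))  # illustrative model

@sgl.function(num_api_spec_tokens=64)
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=32))

state = qa.run(question="Name one prime number.")
print(state["answer"])
if state.error() is not None:  # error() now syncs the worker before returning
    raise state.error()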
sglang/lang/ir.py CHANGED
@@ -81,6 +81,21 @@ class SglSamplingParams:
              "top_p": self.top_p,
              "top_k": self.top_k,
          }
+
+     def to_litellm_kwargs(self):
+         if self.regex is not None:
+             warnings.warn(
+                 "Regular expression is not supported in the LiteLLM backend."
+             )
+         return {
+             "max_tokens": self.max_new_tokens,
+             "stop": self.stop or None,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "frequency_penalty": self.frequency_penalty,
+             "presence_penalty": self.presence_penalty,
+         }

      def to_srt_kwargs(self):
          return {
@@ -97,9 +112,9 @@


  class SglFunction:
-     def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
+     def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
          self.func = func
-         self.api_num_spec_tokens = api_num_spec_tokens
+         self.num_api_spec_tokens = num_api_spec_tokens
          self.bind_arguments = bind_arguments or {}
          self.pin_prefix_rid = None

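The new `to_litellm_kwargs` pairs with the added `sglang/backend/litellm.py` (file 4 in the list above, +90 lines). A hedged sketch of how these kwargs could be handed to LiteLLM; the exact call site inside the new backend is not expanded in this diff, and the model name and prompt are illustrative:

# Hedged sketch: passing SglSamplingParams.to_litellm_kwargs() to litellm.completion.
# The real wiring lives in the new sglang/backend/litellm.py, which this diff does not show.
import litellm
from sglang.lang.ir import SglSamplingParams

params = SglSamplingParams(max_new_tokens=64, temperature=0.7, stop="\n")
kwargs = params.to_litellm_kwargs()  # max_tokens / stop / temperature / top_p / top_k / penalties

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Say hi."}],
    **kwargs,
)
print(response.choices[0].message.content)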
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
      def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
          super().__init__(enable=enable)

+         if tokenizer_path.endswith(".json"):
+             return
+
          from importlib.metadata import version

          if version("outlines") >= "0.0.35":
sglang/srt/flush_cache.py CHANGED
@@ -13,4 +13,4 @@ if __name__ == "__main__":
      args = parser.parse_args()

      response = requests.get(args.url + "/flush_cache")
-     assert response.status_code == 200
+     assert response.status_code == 200
sglang/srt/hf_transformers_utils.py CHANGED
@@ -3,7 +3,8 @@
  import json
  import os
  import warnings
- from typing import List, Optional, Tuple, Union
+ import functools
+ from typing import Optional, Union, AbstractSet, Collection, Literal

  from huggingface_hub import snapshot_download
  from transformers import (
@@ -84,6 +85,9 @@ def get_tokenizer(
      tokenizer_revision: Optional[str] = None,
      **kwargs,
  ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+     if tokenizer_name.endswith(".json"):
+         return TiktokenTokenizer(tokenizer_name)
+
      """Gets a tokenizer for the given model name via Huggingface."""
      if is_multimodal_model(tokenizer_name):
          processor = get_processor(
@@ -170,3 +174,73 @@ def get_processor(
          **kwargs,
      )
      return processor
+
+
+ class TiktokenTokenizer:
+     def __init__(self, tokenizer_path):
+         import tiktoken
+         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+         # Read JSON
+         name = "tmp-json"
+         with open(tokenizer_path, "rb") as fin:
+             tok_dict = json.load(fin)
+
+         mergeable_ranks = {
+             bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+         }
+         special_tokens = {
+             bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+         }
+         assert tok_dict["word_split"] == "V1"
+
+         kwargs = {
+             "name": name,
+             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+             "mergeable_ranks": mergeable_ranks,
+             "special_tokens": special_tokens,
+         }
+         if "default_allowed_special" in tok_dict:
+             default_allowed_special = set(
+                 [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+             )
+         else:
+             default_allowed_special = None
+         if "vocab_size" in tok_dict:
+             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+         tokenizer = tiktoken.Encoding(**kwargs)
+         tokenizer._default_allowed_special = default_allowed_special or set()
+
+         def encode_patched(
+             self,
+             text: str,
+             *,
+             allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+             disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+         ) -> list[int]:
+             if isinstance(allowed_special, set):
+                 allowed_special |= self._default_allowed_special
+             return tiktoken.Encoding.encode(
+                 self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+             )
+         tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+         # Convert to HF interface
+         self.tokenizer = tokenizer
+         self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+         self.vocab_size = tokenizer.n_vocab
+
+     def encode(self, x, add_special_tokens=False):
+         return self.tokenizer.encode(x)
+
+     def decode(self, x):
+         return self.tokenizer.decode(x)
+
+     def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+         if isinstance(batch[0], int):
+             batch = [[x] for x in batch]
+         return self.tokenizer.decode_batch(batch)
+
+     def convert_ids_to_tokens(self, index):
+         return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
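With this class in place, any tokenizer name ending in `.json` is routed to `TiktokenTokenizer` by `get_tokenizer` (see the earlier hunk). A short hedged usage sketch; the path is a placeholder, and the file must follow the schema the constructor reads (`regular_tokens`, `special_tokens`, `word_split` = "V1"):

# Illustrative only: exercising the new .json tokenizer path.
# "/path/to/tokenizer.json" is a placeholder for a file matching the schema above.
from sglang.srt.hf_transformers_utils import get_tokenizer

tok = get_tokenizer("/path/to/tokenizer.json")  # routed to TiktokenTokenizer by the .json check
ids = tok.encode("Hello world")                 # a tiktoken Encoding does the work
print(tok.decode(ids))
print(tok.convert_ids_to_tokens(ids[0]))        # HF-style single-token accessor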
sglang/srt/layers/extend_attention.py CHANGED
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
  CUDA_CAPABILITY = torch.cuda.get_device_capability()


+ @triton.jit
+ def tanh(x):
+     # Tanh is just a scaled sigmoid
+     return 2 * tl.sigmoid(2 * x) - 1
+
+
  @triton.jit
  def _fwd_kernel(
      Q_Extend,
@@ -39,6 +45,7 @@
      BLOCK_DMODEL: tl.constexpr,
      BLOCK_M: tl.constexpr,
      BLOCK_N: tl.constexpr,
+     logit_cap: tl.constexpr,
  ):
      cur_seq = tl.program_id(0)
      cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@
          qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
          qk += tl.dot(q, k)
          qk *= sm_scale
+
+         if logit_cap > 0:
+             qk = logit_cap * tanh(qk / logit_cap)
+
          qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

          n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@
          qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
          qk += tl.dot(q, k)
          qk *= sm_scale
+
+         if logit_cap > 0:
+             qk = logit_cap * tanh(qk / logit_cap)
+
          mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
              start_n + offs_n[None, :]
          )
@@ -176,6 +191,7 @@
      b_seq_len_extend,
      max_len_in_batch,
      max_len_extend,
+     logit_cap=-1,
  ):
      """
      q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@
          BLOCK_N=BLOCK_N,
          num_warps=num_warps,
          num_stages=num_stages,
+         logit_cap=logit_cap,
      )
      cached_kernel = wrap_kernel_launcher(_fwd_kernel)
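The new `logit_cap` argument soft-caps the attention scores before the softmax; the default `-1` leaves the kernel behavior unchanged. A plain PyTorch reference of the same formula, for illustration only (the Triton `tanh` helper above, `2*sigmoid(2x)-1`, is the same function):

# Reference-only PyTorch version of the soft cap used in the Triton kernels above.
# logit_cap > 0 squashes scores smoothly into (-logit_cap, logit_cap); <= 0 is a no-op.
import torch

def soft_cap(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
    if logit_cap > 0:
        return logit_cap * torch.tanh(qk / logit_cap)
    return qk

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, 30.0))  # every value now lies strictly within ±30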