sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py CHANGED
@@ -82,6 +82,19 @@ class SglSamplingParams:
             "top_k": self.top_k,
         }
 
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
@@ -97,9 +110,9 @@ class SglSamplingParams:
 
 
 class SglFunction:
-    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.api_num_spec_tokens = api_num_spec_tokens
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
@@ -107,6 +120,7 @@ class SglFunction:
         argspec = inspect.getfullargspec(func)
         assert argspec.args[0] == "s", 'The first argument must be "s"'
         self.arg_names = argspec.args[1:]
+        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
 
     def bind(self, **kwargs):
         assert all(key in self.arg_names for key in kwargs)
@@ -165,7 +179,18 @@
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
-        assert isinstance(batch_kwargs[0], dict)
+        if not isinstance(batch_kwargs[0], dict):
+            num_programs = len(batch_kwargs)
+            # change the list of argument values to dict of arg_name -> arg_value
+            batch_kwargs = [
+                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                for arg_values in batch_kwargs
+                if isinstance(arg_values, (list, tuple)) and
+                len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+            ]
+            # Ensure to raise an exception if the number of arguments mismatch
+            if len(batch_kwargs) != num_programs:
+                raise Exception("Given arguments mismatch the SGL function signature")
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
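Note: the run_batch change above lets callers pass positional argument lists instead of keyword dicts, with arg_defaults allowing trailing optional arguments to be omitted. A minimal standalone sketch of that conversion (function and helper names here are illustrative, not part of the package):

import inspect

def to_batch_kwargs(func, batch_args):
    # Map positional argument lists onto the function's named parameters,
    # mirroring the conversion added to SglFunction.run_batch.
    spec = inspect.getfullargspec(func)
    arg_names = spec.args[1:]          # skip the leading "s" state argument
    arg_defaults = spec.defaults or []
    required = len(arg_names) - len(arg_defaults)

    converted = [
        {arg_names[i]: v for i, v in enumerate(values)}
        for values in batch_args
        if isinstance(values, (list, tuple))
        and required <= len(values) <= len(arg_names)
    ]
    if len(converted) != len(batch_args):
        raise ValueError("Given arguments mismatch the function signature")
    return converted

def demo(s, question, temperature=0.0):
    pass

print(to_batch_kwargs(demo, [["What is 1+1?"], ["What is 2+2?", 0.7]]))
# [{'question': 'What is 1+1?'}, {'question': 'What is 2+2?', 'temperature': 0.7}]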
sglang/launch_server.py CHANGED
@@ -1,6 +1,9 @@
+"""Launch the inference server."""
+
 import argparse
 
-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
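Note: in 0.1.18, ServerArgs is imported from its own module (sglang.srt.server_args) while launch_server stays in sglang.srt.server. A minimal sketch of how such an entry point is typically completed; it assumes ServerArgs exposes add_cli_args/from_cli_args helpers, which this diff does not show, so treat those method names as assumptions:

import argparse

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Assumed helpers: register the ServerArgs fields as CLI flags, then rebuild the dataclass.
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    launch_server(server_args)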
sglang/launch_server_llavavid.py CHANGED
@@ -1,10 +1,11 @@
+"""Launch the inference server for Llava-video model."""
+
 import argparse
 import multiprocessing as mp
 
 from sglang.srt.server import ServerArgs, launch_server
 
 if __name__ == "__main__":
-
     model_overide_args = {}
 
     model_overide_args["mm_spatial_pool_stride"] = 2
sglang/srt/constrained/__init__.py CHANGED
@@ -1,13 +1,19 @@
 import json
 from typing import Dict, Optional, Union
 
-from outlines.caching import cache as disk_cache
-from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
-from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
 
+try:
+    from outlines.caching import cache as disk_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+    from outlines.models.transformers import TransformerTokenizer
+except ImportError as e:
+    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    raise
+
 try:
     from outlines.fsm.json_schema import build_regex_from_object
 except ImportError:
@@ -28,11 +34,12 @@ except ImportError:
 
 
 __all__ = [
-    "RegexFSM",
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -1,4 +1,6 @@
-from sglang.srt.constrained import RegexFSM, TransformerTokenizer
+"""Cache for the compressed finite state machine."""
+
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
 
 
@@ -6,6 +8,10 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
+        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+            # Do not support TiktokenTokenizer or SentencePieceTokenizer
+            return
+
         from importlib.metadata import version
 
         if version("outlines") >= "0.0.35":
@@ -22,4 +28,4 @@ class FSMCache(BaseCache):
         )
 
     def init_value(self, regex):
-        return RegexFSM(regex, self.outlines_tokenizer)
+        return RegexGuide(regex, self.outlines_tokenizer)
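Note: the cached RegexGuide replaces the removed RegexFSM as the per-regex decoding guide. A minimal sketch of how such a guide is typically consumed during decoding, assuming the outlines Guide interface (get_next_instruction / get_next_state / is_final_state); the real integration lives elsewhere in the runtime, not in this cache, so the loop below is illustrative only:

def constrained_decode(guide, sample_from_allowed, max_steps=256):
    # Walk the regex FSM: only tokens allowed by the current state may be sampled,
    # then the state advances with the chosen token.
    state = 0  # initial guide state
    token_ids = []
    for _ in range(max_steps):
        allowed = guide.get_next_instruction(state).tokens
        token_id = sample_from_allowed(allowed)  # caller masks logits and samples
        token_ids.append(token_id)
        state = guide.get_next_state(state, token_id)
        if guide.is_final_state(state):
            break
    return token_ids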
sglang/srt/constrained/jump_forward.py CHANGED
@@ -1,17 +1,43 @@
-import interegular
+"""
+Faster constrained decoding.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+from collections import defaultdict
 
-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
 
             fsm_info: FSMInfo = regex_fsm.fsm_info
 
@@ -21,40 +47,93 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)
 
             transitions = fsm_info.transitions
-            dirty_states = set()
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
 
             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-
-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) > 1:
+                        # FIXME: This logic is due to the leading \x00
+                        # https://github.com/outlines-dev/outlines/pull/930
+                        byte_ = int(symbols[0][1:], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e
 
             return state_to_jump_forward
 
         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
 
-    def valid_states(self):
-        return self.state_to_jump_forward.keys()
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
 
-    def jump_forward(self, state):
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None
 
-        jump_forward_str = ""
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
 
 
 class JumpForwardCache(BaseCache):
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)
 
 
-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
 
 
 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
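Note: the core idea behind JumpForwardMap is that whenever the current FSM state has exactly one outgoing edge, the following characters are fully determined by the regex, so the decoder can append them directly instead of generating them token by token. A toy, self-contained illustration of that walk (hand-written FSM, not the sglang integration):

# Each state with a single outgoing edge maps to (symbol, next_state),
# mirroring the symbol-level entries stored in state_to_jump_forward.
edges = {
    0: ("T", 1),
    1: ("h", 2),
    2: ("e", 3),
    3: (" ", 4),
    # state 4 branches in a real FSM, so it has no jump-forward entry
}

def jump_forward(state):
    text = []
    while state in edges:
        symbol, state = edges[state]
        text.append(symbol)
    return "".join(text), state

print(jump_forward(0))  # ('The ', 4): four characters appended with zero model forward passes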
sglang/srt/conversation.py CHANGED
@@ -1,3 +1,5 @@
+"""Conversation templates."""
+
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
sglang/srt/flush_cache.py CHANGED
@@ -1,4 +1,6 @@
 """
+Flush the KV cache.
+
 Usage:
 python3 -m sglang.srt.flush_cache --url http://localhost:30000
 """
@@ -13,4 +15,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     response = requests.get(args.url + "/flush_cache")
-    assert response.status_code == 200
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py CHANGED
@@ -1,9 +1,10 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -84,6 +85,12 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -170,3 +177,125 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        return self.encode(ret) if tokenize else ret
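Note: the chat template embedded in both new tokenizer wrappers renders messages into a simple Human/System/Assistant transcript separated by <|separator|> tokens. Rendering it standalone with jinja2 (template string copied from the diff above; the example messages are made up):

from jinja2 import Template

chat_template = Template(
    "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
print(chat_template.render(messages=messages, add_generation_prompt=True))
# System: You are a helpful assistant.<|separator|>
#
# Human: What is the capital of France?<|separator|>
#
# Assistant: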
sglang/srt/layers/extend_attention.py CHANGED
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)