sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ from sglang.lang.ir import (
     SglVarScopeEnd,
     SglVideo,
 )
-from sglang.utils import
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
+)
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -61,7 +66,7 @@ def run_program(
         default_sampling_para,
         chat_template=None,
         stream=stream,
-        api_num_spec_tokens=program.api_num_spec_tokens,
+        num_api_spec_tokens=program.num_api_spec_tokens,
     )
     state = ProgramState(stream_executor)
 
@@ -173,7 +178,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
-        api_num_spec_tokens=None,
+        num_api_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -181,20 +186,16 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens
 
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.error = None
+        self.error_ = None
 
         # For completion
         self.text_ = ""  # The full text
 
-        # For speculative execution
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -208,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None
 
+        # For speculative execution
+        self.num_api_spec_tokens = num_api_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -286,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)
 
+        # TODO(ying): handle API speculative execution
+
         return exes
 
     def text(self):
@@ -296,6 +303,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_
 
+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -314,7 +325,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -334,7 +345,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.error = error
+        self.error_ = error
 
         if self.stream_text_event:
             self.stream_text_event.set()
@@ -383,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")
 
-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
            self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
+
         self.text_ += value
 
     def _execute_image(self, expr: SglImage):
@@ -413,65 +435,80 @@ class StreamExecutor:
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.num_api_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
 
         if not self.stream:
-            if self.api_num_spec_tokens is not None:
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        stop_len = 0
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos, stop_len = find_stop()
-                    if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
+            if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
-                    self,
+                    self,
+                    sampling_params=sampling_params,
                 )
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)
 
             self.text_ += comp
 
@@ -479,6 +516,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.num_api_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
            )
@@ -534,10 +574,19 @@ class StreamExecutor:
 
         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
 
-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)
 
     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()
 
         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -564,8 +613,6 @@ class StreamExecutor:
         # OpenAI chat API format
         self.messages_.append({"role": expr.role, "content": new_text})
 
-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
 
@@ -709,7 +756,7 @@ class ProgramState:
         return self.stream_executor.sync()
 
     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()
 
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
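Note on the interpreter change above: speculative API execution now has two paths. Completion-style backends go through the new `_spec_gen` helper, which over-generates once (up to `num_api_spec_tokens`) and slices the cached `speculated_text` locally for each `gen()` call, while chat-only backends buffer the fill/gen calls (`spec_fill`, `spec_var_name`) and flush them in `_execute_role_end` via `role_end_generate`. The snippet below is a minimal standalone sketch of the slicing idea only; it is not sglang code, and `fake_generate` plus the literal strings are invented for illustration.

```python
# Standalone sketch of the _spec_gen slicing idea (not the sglang implementation).
def fake_generate(max_new_tokens):
    # Stand-in for backend.generate(): one large speculative continuation.
    return "Paris is the capital of France.\nConfidence: high\nUnrelated trailing text"

def spec_slice(speculated_text, stop, max_new_tokens, num_api_spec_tokens=256):
    if speculated_text == "":
        # Issue one big request instead of one request per gen() call.
        speculated_text = fake_generate(max(max_new_tokens, num_api_spec_tokens))
    pos = speculated_text.find(stop) if stop is not None else max_new_tokens
    if pos == -1:
        pos = min(max_new_tokens, len(speculated_text))
    # Consume the matched span; the remainder stays cached for the next gen().
    return speculated_text[:pos], speculated_text[pos:]

answer, rest = spec_slice("", stop="\n", max_new_tokens=64)
confidence, _ = spec_slice(rest.lstrip("\n"), stop="\n", max_new_tokens=64)
print(answer)      # Paris is the capital of France.
print(confidence)  # Confidence: high
```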
sglang/lang/ir.py
CHANGED
@@ -81,6 +81,21 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }
+
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the LiteLLM backend."
+            )
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
 
     def to_srt_kwargs(self):
         return {
@@ -97,9 +112,9 @@ class SglSamplingParams:
 
 
 class SglFunction:
-    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.api_num_spec_tokens = api_num_spec_tokens
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
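The `api_num_spec_tokens` to `num_api_spec_tokens` rename in `SglFunction` also surfaces in the user-facing decorator (the `api.py` change is not expanded here). A hedged usage sketch, assuming the `sgl.function` decorator simply forwards this argument:

```python
import sglang as sgl

# Hypothetical usage: let the executor speculatively request up to 64 extra
# tokens per API call and reuse the surplus for later gen() calls.
# (Requires a default backend, e.g. sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo")).)
@sgl.function(num_api_spec_tokens=64)
def city_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32, stop="\n") + "\n"
    s += "Confidence: " + sgl.gen("confidence", max_tokens=8, stop="\n")

state = city_qa.run(question="What is the capital of France?")
print(state["answer"], state["confidence"])
```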
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
+        if tokenizer_path.endswith(".json"):
+            return
+
         from importlib.metadata import version
 
         if version("outlines") >= "0.0.35":
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -84,6 +85,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -170,3 +174,73 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            )
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
sglang/srt/layers/extend_attention.py
CHANGED
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
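For reference, the new `logit_cap` parameter applies a smooth cap to the attention scores before masking and softmax; values <= 0 (the default -1) disable it. The diff also defines `tanh` via `tl.sigmoid`, presumably because the targeted Triton version lacks a tanh primitive in `tl`. A plain PyTorch sketch of the same transform, with illustrative numbers:

```python
import torch

def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Mirrors the kernel's `qk = logit_cap * tanh(qk / logit_cap)`:
    # scores are squashed smoothly into (-logit_cap, logit_cap).
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores  # capping disabled

qk = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(qk, 30.0))   # ~[-29.92, -4.95, 0.00, 4.95, 29.92]
print(soft_cap(qk, -1.0))   # unchanged
```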