sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -24,6 +24,7 @@ from sglang.api import (
|
|
24
24
|
user_end,
|
25
25
|
video,
|
26
26
|
)
|
27
|
+
from sglang.global_config import global_config
|
27
28
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
28
29
|
from sglang.lang.choices import (
|
29
30
|
greedy_token_selection,
|
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
|
|
31
32
|
unconditional_likelihood_normalized,
|
32
33
|
)
|
33
34
|
from sglang.utils import LazyImport
|
35
|
+
from sglang.version import __version__
|
34
36
|
|
35
37
|
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
36
38
|
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
|
38
40
|
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
39
41
|
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
40
42
|
|
41
|
-
# Other configs
|
42
|
-
from sglang.global_config import global_config
|
43
|
-
from sglang.version import __version__
|
44
|
-
|
45
43
|
__all__ = [
|
46
44
|
"Engine",
|
47
45
|
"Runtime",
|
sglang/bench_one_batch.py
CHANGED
@@ -207,7 +207,7 @@ def prepare_extend_inputs_for_correctness_test(
|
|
207
207
|
|
208
208
|
|
209
209
|
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
210
|
-
input_ids = np.
|
210
|
+
input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
211
211
|
sampling_params = SamplingParams(
|
212
212
|
temperature=0,
|
213
213
|
max_new_tokens=BenchArgs.output_len,
|
@@ -396,7 +396,7 @@ def latency_test_run_once(
|
|
396
396
|
decode_latencies.append(latency)
|
397
397
|
if i < 5:
|
398
398
|
rank_print(
|
399
|
-
f"Decode.
|
399
|
+
f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
400
400
|
)
|
401
401
|
|
402
402
|
if profile:
|
sglang/bench_serving.py
CHANGED
@@ -690,7 +690,6 @@ def sample_random_requests(
|
|
690
690
|
dataset_path: str,
|
691
691
|
random_sample: bool = True,
|
692
692
|
) -> List[Tuple[str, int, int]]:
|
693
|
-
|
694
693
|
input_lens = np.random.randint(
|
695
694
|
max(int(input_len * range_ratio), 1),
|
696
695
|
input_len + 1,
|
@@ -707,10 +706,6 @@ def sample_random_requests(
|
|
707
706
|
|
708
707
|
# Download sharegpt if necessary
|
709
708
|
if not os.path.isfile(dataset_path):
|
710
|
-
print(
|
711
|
-
"If you do not want to randomly sample from a dataset,"
|
712
|
-
" please use --dataset-name random-ids."
|
713
|
-
)
|
714
709
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
715
710
|
|
716
711
|
# Load the dataset.
|
@@ -1029,7 +1024,9 @@ async def benchmark(
|
|
1029
1024
|
warmup_outputs = await asyncio.gather(*warmup_tasks)
|
1030
1025
|
|
1031
1026
|
# Check if at least one warmup request succeeded
|
1032
|
-
if
|
1027
|
+
if args.warmup_requests > 0 and not any(
|
1028
|
+
output.success for output in warmup_outputs
|
1029
|
+
):
|
1033
1030
|
raise ValueError(
|
1034
1031
|
"Warmup failed - Please make sure benchmark arguments "
|
1035
1032
|
f"are correctly specified. Error: {warmup_outputs[0].error}"
|
@@ -0,0 +1,136 @@
|
|
1
|
+
"""
|
2
|
+
Compile DeepGEMM Kernels for a model with the specified server arguments
|
3
|
+
|
4
|
+
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
|
5
|
+
It accepts server arguments (the same as launch_server.py).
|
6
|
+
|
7
|
+
Usage:
|
8
|
+
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
9
|
+
|
10
|
+
"""
|
11
|
+
|
12
|
+
import argparse
|
13
|
+
import dataclasses
|
14
|
+
import multiprocessing
|
15
|
+
import os
|
16
|
+
import time
|
17
|
+
|
18
|
+
import requests
|
19
|
+
|
20
|
+
from sglang.srt.entrypoints.http_server import launch_server
|
21
|
+
from sglang.srt.managers.io_struct import GenerateReqInput
|
22
|
+
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
23
|
+
from sglang.srt.server_args import ServerArgs
|
24
|
+
from sglang.srt.utils import kill_process_tree
|
25
|
+
from sglang.srt.warmup import warmup
|
26
|
+
|
27
|
+
# Start child processes with "spawn"; force=True overrides any start method
# already configured in this interpreter.
multiprocessing.set_start_method(method="spawn", force=True)

# Marks this process as running the DeepGEMM pre-compile stage. The original
# note says this reduces warnings; presumably read elsewhere in sglang
# (TODO confirm the reader).
os.environ.update({"SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE": "1"})
|
31
|
+
|
32
|
+
|
33
|
+
@dataclasses.dataclass
class CompileArgs:
    """Options specific to the DeepGEMM pre-compilation script (on top of ServerArgs)."""

    # Seconds. Used as the limit for waiting on the launched server and as the
    # server's watchdog timeout (see refine_server_args).
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register this dataclass's command-line flags on ``parser``."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from parsed command-line ``args``.

        Each field is read from ``args`` by name and converted with the type
        of that field's default value (e.g. ``int`` for ``timeout``).
        """
        converted = {}
        for field in dataclasses.fields(cls):
            to_field_type = type(field.default)
            converted[field.name] = to_field_type(getattr(args, field.name))
        return cls(**converted)
|
48
|
+
|
49
|
+
|
50
|
+
@warmup("compile-deep-gemm")
async def warm_up_compile(tokenizer_manager: TokenizerManager):
    """Warmup registered as "compile-deep-gemm".

    Sends one small greedy request (4 dummy input ids, 8 new tokens, EOS
    ignored) through ``tokenizer_manager``. It is intended to make the server
    hit its DeepGEMM calls so they get compiled. Only the first item of the
    response stream is awaited.
    """
    print("\nGenerate warm up request for compiling DeepGEMM...\n")
    params = {
        "temperature": 0.0,
        "max_new_tokens": 8,
        "ignore_eos": True,
    }
    request = GenerateReqInput(input_ids=[0, 1, 2, 3], sampling_params=params)
    response_stream = tokenizer_manager.generate_request(request, None)
    await response_stream.__anext__()
|
62
|
+
|
63
|
+
|
64
|
+
def launch_server_internal(server_args):
    """Run the sglang HTTP server in this (child) process.

    Any exception from ``launch_server`` propagates unchanged. Whether it
    returns or raises, ``kill_process_tree`` is then called on this process
    with ``include_parent=False``, presumably cleaning up only the processes
    it spawned (TODO confirm against kill_process_tree).
    """
    # A plain try/finally is enough; the former `except Exception as e: raise e`
    # added nothing except an extra traceback frame.
    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)
|
71
|
+
|
72
|
+
|
73
|
+
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a child process and wait until it answers HTTP.

    Polls ``GET http://{host}:{port}/v1/models`` every 10 seconds until it
    returns status 200, then returns the started ``multiprocessing.Process``.

    Raises:
        RuntimeError: the server process exited before becoming reachable.
        TimeoutError: the server was not reachable within
            ``compile_args.timeout`` seconds. The server process tree is
            killed first, so a restarted command does not collide with a
            leftover server.
    """
    poll_interval = 10  # seconds; also bounds each individual HTTP request

    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout
    headers = {
        "Content-Type": "application/json; charset=utf-8",
    }

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # Without a request timeout a hung connection could block past
            # the overall deadline.
            response = requests.get(
                f"{base_url}/v1/models", headers=headers, timeout=poll_interval
            )
            if response.status_code == 200:
                return proc
        except requests.RequestException:
            pass
        # Fail fast if the server crashed instead of polling until the timeout.
        if not proc.is_alive():
            raise RuntimeError(
                f"DeepGEMM compilation server exited early with exit code {proc.exitcode}."
            )
        time.sleep(poll_interval)

    # Don't leak the still-starting server (it would keep the port/GPU busy).
    kill_process_tree(proc.pid)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
|
97
|
+
|
98
|
+
|
99
|
+
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
    """Adjust ``server_args`` in place for a compile-only server run.

    - Disables CUDA graph and torch.compile to save time.
    - Sets the watchdog timeout to ``compile_args.timeout``, because
      compilation can take a long time.
    - Selects the "compile-deep-gemm" warmup, which sends the request that
      drives DeepGEMM compilation.
    """
    # Disable CUDA graph and torch compile to save time.
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    print("Disable CUDA Graph and Torch Compile to save time...")

    # Set watchdog timeout to compile_args.timeout because compilation will take a long time.
    server_args.watchdog_timeout = compile_args.timeout
    server_args.warmups = "compile-deep-gemm"
|
108
|
+
|
109
|
+
|
110
|
+
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Drive one DeepGEMM pre-compilation run.

    Prints a notice, starts the server and waits until it is reachable, then
    kills the server's process tree and reports success.
    """
    notice = (
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )
    print(notice)

    server_proc = launch_server_process_and_send_one_request(server_args, compile_args)
    kill_process_tree(server_proc.pid)

    print("\nDeepGEMM Kernels compilation finished successfully.")
|
124
|
+
|
125
|
+
|
126
|
+
if __name__ == "__main__":
    # One parser carries both the regular server flags and the compile flags.
    cli = argparse.ArgumentParser()
    for args_cls in (ServerArgs, CompileArgs):
        args_cls.add_cli_args(cli)
    parsed = cli.parse_args()

    server_args = ServerArgs.from_cli_args(parsed)
    compile_args = CompileArgs.from_cli_args(parsed)

    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)
|
sglang/lang/backend/anthropic.py
CHANGED
sglang/lang/backend/openai.py
CHANGED
@@ -2,7 +2,7 @@ import dataclasses
|
|
2
2
|
import logging
|
3
3
|
import time
|
4
4
|
import warnings
|
5
|
-
from typing import
|
5
|
+
from typing import List, Optional, Union
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
|
@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
|
|
161
161
|
prompt = s.text_
|
162
162
|
|
163
163
|
kwargs = sampling_params.to_openai_kwargs()
|
164
|
-
if
|
164
|
+
if (
|
165
|
+
self.model_name.startswith("o1")
|
166
|
+
or self.model_name.startswith("o3")
|
167
|
+
or "o1" in self.model_name
|
168
|
+
):
|
165
169
|
kwargs.pop("max_tokens", None)
|
166
170
|
else:
|
167
171
|
kwargs.pop("max_completion_tokens", None)
|
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):
|
|
324
324
|
|
325
325
|
def _assert_success(self, res):
|
326
326
|
if res.status_code != 200:
|
327
|
-
|
327
|
+
try:
|
328
|
+
content = res.json()
|
329
|
+
except json.JSONDecodeError:
|
330
|
+
content = res.text
|
331
|
+
raise RuntimeError(content)
|
328
332
|
|
329
333
|
|
330
334
|
def compute_normalized_prompt_logprobs(input_logprobs):
|
sglang/lang/backend/vertexai.py
CHANGED
sglang/lang/compiler.py
CHANGED
@@ -5,13 +5,7 @@ from typing import List, Union
|
|
5
5
|
|
6
6
|
from sglang.global_config import global_config
|
7
7
|
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
8
|
-
from sglang.lang.ir import
|
9
|
-
SglArgument,
|
10
|
-
SglConstantText,
|
11
|
-
SglExpr,
|
12
|
-
SglSamplingParams,
|
13
|
-
SglVariable,
|
14
|
-
)
|
8
|
+
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
|
15
9
|
|
16
10
|
|
17
11
|
def compile_func(function, backend):
|
sglang/lang/tracer.py
CHANGED
@@ -1,20 +1,16 @@
|
|
1
1
|
"""Tracing a program."""
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Dict, List, Optional
|
5
5
|
|
6
|
-
from sglang.global_config import global_config
|
7
6
|
from sglang.lang.backend.base_backend import BaseBackend
|
8
7
|
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
9
8
|
from sglang.lang.ir import (
|
10
9
|
SglArgument,
|
11
|
-
SglCommitLazy,
|
12
|
-
SglConcateAndAppend,
|
13
10
|
SglConstantText,
|
14
11
|
SglExpr,
|
15
12
|
SglExprList,
|
16
13
|
SglFork,
|
17
|
-
SglFunction,
|
18
14
|
SglGen,
|
19
15
|
SglGetForkItem,
|
20
16
|
SglRoleBegin,
|
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
|
|
230
226
|
self.cur_role = None
|
231
227
|
|
232
228
|
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
    """Close a variable scope by binding ``expr.name`` to a new SglVariable.

    The new variable's source is the tracer's current ``last_node``.
    """
    var_name = expr.name
    self.variables[var_name] = SglVariable(var_name, source=self.last_node)
|
235
231
|
|
236
232
|
def get_var(self, name):
|
237
233
|
ret = self.arguments.get(name, None)
|
sglang/srt/_custom_ops.py
CHANGED
@@ -73,8 +73,11 @@ class ModelConfig:
|
|
73
73
|
)
|
74
74
|
|
75
75
|
if enable_multimodal is None:
|
76
|
-
if self.hf_config.architectures == "Llama4ForConditionalGeneration":
|
76
|
+
if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
|
77
77
|
enable_multimodal = False
|
78
|
+
logger.info(
|
79
|
+
"Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
|
80
|
+
)
|
78
81
|
else:
|
79
82
|
enable_multimodal = True
|
80
83
|
|
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
|
19
19
|
import dataclasses
|
20
20
|
import logging
|
21
21
|
from collections import defaultdict
|
22
|
+
from typing import Optional
|
22
23
|
|
23
24
|
import interegular
|
24
25
|
from interegular import InvalidSyntax
|
25
|
-
from outlines.caching import cache
|
26
|
+
from outlines.caching import cache
|
27
|
+
|
28
|
+
from sglang.srt.utils import get_bool_env_var
|
26
29
|
|
27
30
|
try:
|
28
31
|
# outlines >= 0.1.0
|
@@ -34,6 +37,9 @@ except ImportError:
|
|
34
37
|
|
35
38
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
36
39
|
|
40
|
+
# Env var is set in sglang.srt.server_args.ServerArgs.__post_init__.
# "true" is passed as the fallback value, presumably used when the variable
# is unset, so the outlines disk cache would be disabled by default.
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
|
42
|
+
|
37
43
|
logger = logging.getLogger(__name__)
|
38
44
|
|
39
45
|
|
@@ -45,6 +51,13 @@ class JumpEdge:
|
|
45
51
|
byte_next_state: int = None
|
46
52
|
|
47
53
|
|
54
|
+
def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    """Decorator factory that uses outlines' on-disk ``cache`` unless disabled.

    When ``DISABLE_DISK_CACHE`` is truthy, the returned decorator leaves the
    function unchanged (no caching). Otherwise this returns
    ``cache(expire, typed, ignore)``.
    """
    if DISABLE_DISK_CACHE:
        # Must return the function itself. Returning None would replace every
        # decorated function with None, so calling it raises TypeError.
        return lambda fn: fn
    return cache(expire, typed, ignore)
|
59
|
+
|
60
|
+
|
48
61
|
@disk_cache()
|
49
62
|
def init_state_to_jump_forward(regex_string):
|
50
63
|
try:
|
@@ -0,0 +1,141 @@
|
|
1
|
+
# Adapt from
|
2
|
+
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
|
3
|
+
|
4
|
+
from typing import List, Optional, Union
|
5
|
+
|
6
|
+
import torch
|
7
|
+
import triton
|
8
|
+
import triton.language as tl
|
9
|
+
|
10
|
+
from sglang.srt.utils import get_device_core_count
|
11
|
+
|
12
|
+
|
13
|
+
@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Set masked-out logits to -inf in place (Triton kernel).

    The bitmask is packed into int32 words: bit ``j`` of word ``w`` in a row
    refers to token ``32 * w + j``. A 0 bit means the token is masked (its
    logit is overwritten with -inf); a 1 bit leaves the logit untouched.

    Work is split into ``num_rows * cdiv(vocab_size, BLOCK_SIZE)`` items, one
    per (row, BLOCK_SIZE-wide vocab chunk). Program ``pid`` handles items
    ``pid, pid + NUM_SMS, pid + 2 * NUM_SMS, ...``.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor, modified in place.

    bitmask_ptr : tl.tensor
        Pointer to the packed int32 bitmask tensor.

    indices_ptr : Optional[tl.tensor]
        Optional pointer to row indices. Work row ``r`` operates on row
        ``indices[r]`` of both logits and bitmask; when None, on row ``r``.

    num_rows : int
        Number of rows to process. This is the number of entries in the
        indices tensor (duplicates included) when indices_ptr is given,
        otherwise the batch size.

    vocab_size : int
        Only positions ``< vocab_size`` in a row are ever written. If the
        logits have vocab padding, this is the real vocabulary size.

    logits_strides : int
        Distance, in elements, between consecutive logits rows.

    bitmask_strides : int
        Distance, in int32 words, between consecutive bitmask rows. It is also
        used as the number of valid words per row, so bitmask rows are assumed
        densely packed (NOTE(review): confirm callers pass a row-contiguous
        bitmask).

    NUM_SMS : int
        Stride of the work-item loop. The launcher in this module uses a grid
        of at most NUM_SMS programs.

    BLOCK_SIZE : int
        Vocab entries per work item. Must be a multiple of 32, because the
        (BLOCK_SIZE // 32, 32) bit matrix is reshaped to BLOCK_SIZE.
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        # One int32 word covers 32 consecutive vocab positions.
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        # Unpack LSB-first; True where the bit is 0, i.e. the token must be masked.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )
|
82
|
+
|
83
|
+
|
84
|
+
def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    """Set the logits of tokens rejected by ``bitmask`` to -inf, in place.

    Args:
        logits: tensor of shape (batch, width) or (width,), modified in place.
        bitmask: int32 tensor of packed bits, (batch, words) or (words,), with
            at most ceil(width / 32) words per row. Bit ``j`` of word ``w``
            refers to token ``32 * w + j``; a 0 bit masks the token. Tokens
            beyond ``words * 32`` are left untouched.
        indices: optional row indices to process (list, tuple, tensor, or
            anything ``torch.as_tensor`` accepts). When None, every row is
            processed and the batch sizes of logits and bitmask must match.

    Raises:
        AssertionError: bitmask is not int32, is wider than the logits need,
            or (without ``indices``) the batch sizes differ.
    """
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes; 1-D inputs are treated as a single row.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    if indices is not None:
        # as_tensor avoids the copy (and UserWarning) torch.tensor() makes for
        # tensor inputs, and also accepts tuples and other sequences, which the
        # previous list/Tensor isinstance check silently ignored.
        indices = torch.as_tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if num_rows == 0:
        # Nothing to mask; avoid launching a kernel (possibly with an empty grid).
        return

    if NUM_SMS > 0:
        grid = (NUM_SMS,)
    else:
        # Core count unavailable: one program per work item, with the loop
        # stride at least the grid size so each program does one item.
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
|
@@ -25,13 +25,16 @@ from xgrammar import (
|
|
25
25
|
StructuralTagItem,
|
26
26
|
TokenizerInfo,
|
27
27
|
allocate_token_bitmask,
|
28
|
-
apply_token_bitmask_inplace,
|
29
28
|
)
|
30
29
|
|
31
30
|
from sglang.srt.constrained.base_grammar_backend import (
|
32
31
|
BaseGrammarBackend,
|
33
32
|
BaseGrammarObject,
|
34
33
|
)
|
34
|
+
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
35
|
+
apply_token_bitmask_inplace_triton,
|
36
|
+
)
|
37
|
+
from sglang.srt.utils import get_bool_env_var
|
35
38
|
|
36
39
|
logger = logging.getLogger(__name__)
|
37
40
|
|
@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
55
58
|
self.override_stop_tokens = override_stop_tokens
|
56
59
|
self.finished = False
|
57
60
|
|
61
|
+
# Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
|
62
|
+
# class init site to avoid re-initializing CUDA in forked subprocess.
|
63
|
+
from xgrammar.kernels import apply_token_bitmask_inplace_kernels
|
64
|
+
|
65
|
+
self.use_token_bitmask_triton = get_bool_env_var(
|
66
|
+
"SGLANG_TOKEN_BITMASK_TRITON", "false"
|
67
|
+
)
|
68
|
+
self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
|
69
|
+
"cuda", None
|
70
|
+
)
|
71
|
+
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
|
72
|
+
|
58
73
|
def accept_token(self, token: int):
|
59
74
|
assert self.matcher.accept_token(token)
|
60
75
|
|
@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
|
|
97
112
|
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
98
113
|
return vocab_mask.to(device, non_blocking=True)
|
99
114
|
|
100
|
-
|
101
|
-
|
102
|
-
|
115
|
+
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    """Mask disallowed tokens in ``logits`` in place using the packed ``vocab_mask``.

    Kernel choice:
      * CUDA logits: xgrammar's CUDA kernel, when it exists and the Triton
        path has not been forced (``use_token_bitmask_triton``);
      * CPU logits: xgrammar's CPU kernel, when it exists;
      * otherwise: the Triton implementation.

    Returns whatever the chosen xgrammar kernel returns; the Triton fallback
    returns None.
    """
    device_type = logits.device.type
    cuda_kernel_allowed = not self.use_token_bitmask_triton

    if cuda_kernel_allowed and device_type == "cuda" and self.apply_vocab_mask_cuda:
        return self.apply_vocab_mask_cuda(logits, vocab_mask)

    if device_type == "cpu" and self.apply_vocab_mask_cpu:
        return self.apply_vocab_mask_cpu(logits, vocab_mask)

    apply_token_bitmask_inplace_triton(logits, vocab_mask)
|
103
125
|
|
104
126
|
def copy(self):
|
105
127
|
matcher = GrammarMatcher(
|
@@ -136,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|
136
158
|
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
137
159
|
try:
|
138
160
|
if key_string == "$$ANY$$":
|
161
|
+
# Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
|
139
162
|
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
140
163
|
else:
|
141
164
|
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
sglang/srt/custom_op.py
CHANGED
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
|
|
42
42
|
return self.forward_hip
|
43
43
|
else:
|
44
44
|
return self.forward_native
|
45
|
-
|
46
|
-
|
47
|
-
if _is_cuda:
|
48
|
-
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
49
|
-
|
50
|
-
def scaled_fp8_quant(
|
51
|
-
input: torch.Tensor,
|
52
|
-
scale: Optional[torch.Tensor] = None,
|
53
|
-
num_token_padding: Optional[int] = None,
|
54
|
-
use_per_token_if_dynamic: bool = False,
|
55
|
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
56
|
-
"""
|
57
|
-
Quantize input tensor to FP8 (8-bit floating point) format.
|
58
|
-
|
59
|
-
Args:
|
60
|
-
input (torch.Tensor): Input tensor to be quantized
|
61
|
-
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
62
|
-
If None, scales will be computed dynamically.
|
63
|
-
num_token_padding (Optional[int]): If specified, pad the first dimension
|
64
|
-
of the output to at least this value.
|
65
|
-
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
66
|
-
determines the quantization granularity:
|
67
|
-
- True: compute scale per token
|
68
|
-
- False: compute single scale per tensor
|
69
|
-
|
70
|
-
Returns:
|
71
|
-
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
72
|
-
- quantized_tensor: The FP8 quantized version of input
|
73
|
-
- scale_tensor: The scaling factors used for quantization
|
74
|
-
|
75
|
-
Raises:
|
76
|
-
AssertionError: If input is not 2D or if static scale's numel != 1
|
77
|
-
"""
|
78
|
-
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
79
|
-
shape = input.shape
|
80
|
-
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
81
|
-
if num_token_padding:
|
82
|
-
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
83
|
-
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
84
|
-
|
85
|
-
if scale is None:
|
86
|
-
# Dynamic scaling
|
87
|
-
if use_per_token_if_dynamic:
|
88
|
-
scale = torch.empty(
|
89
|
-
(shape[0], 1), device=input.device, dtype=torch.float32
|
90
|
-
)
|
91
|
-
sgl_per_token_quant_fp8(input, output, scale)
|
92
|
-
else:
|
93
|
-
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
94
|
-
sgl_per_tensor_quant_fp8(
|
95
|
-
input, output, scale, is_static=False
|
96
|
-
) # False for dynamic
|
97
|
-
else:
|
98
|
-
# Static scaling
|
99
|
-
assert (
|
100
|
-
scale.numel() == 1
|
101
|
-
), f"Expected scalar scale, got numel={scale.numel()}"
|
102
|
-
sgl_per_tensor_quant_fp8(
|
103
|
-
input, output, scale, is_static=True
|
104
|
-
) # True for static
|
105
|
-
|
106
|
-
return output, scale
|