sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -24,6 +24,7 @@ from sglang.api import (
  user_end,
  video,
  )
+ from sglang.global_config import global_config
  from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.lang.choices import (
  greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
  unconditional_likelihood_normalized,
  )
  from sglang.utils import LazyImport
+ from sglang.version import __version__

  ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
  Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
  OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
  VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

- # Other configs
- from sglang.global_config import global_config
- from sglang.version import __version__
-
  __all__ = [
  "Engine",
  "Runtime",
sglang/bench_one_batch.py CHANGED
@@ -207,7 +207,7 @@ def prepare_extend_inputs_for_correctness_test(


  def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
- input_ids = np.ones((batch_size, input_len), dtype=np.int32)
+ input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
  sampling_params = SamplingParams(
  temperature=0,
  max_new_tokens=BenchArgs.output_len,
@@ -396,7 +396,7 @@ def latency_test_run_once(
  decode_latencies.append(latency)
  if i < 5:
  rank_print(
- f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
  )

  if profile:
sglang/bench_serving.py CHANGED
@@ -690,7 +690,6 @@ def sample_random_requests(
  dataset_path: str,
  random_sample: bool = True,
  ) -> List[Tuple[str, int, int]]:
-
  input_lens = np.random.randint(
  max(int(input_len * range_ratio), 1),
  input_len + 1,
@@ -707,10 +706,6 @@

  # Download sharegpt if necessary
  if not os.path.isfile(dataset_path):
- print(
- "If you do not want to randomly sample from a dataset,"
- " please use --dataset-name random-ids."
- )
  dataset_path = download_and_cache_file(SHAREGPT_URL)

  # Load the dataset.
@@ -1029,7 +1024,9 @@ async def benchmark(
  warmup_outputs = await asyncio.gather(*warmup_tasks)

  # Check if at least one warmup request succeeded
- if not any(output.success for output in warmup_outputs):
+ if args.warmup_requests > 0 and not any(
+ output.success for output in warmup_outputs
+ ):
  raise ValueError(
  "Warmup failed - Please make sure benchmark arguments "
  f"are correctly specified. Error: {warmup_outputs[0].error}"
sglang/compile_deep_gemm.py ADDED
@@ -0,0 +1,136 @@
+ """
+ Compile DeepGEMM Kernels for a model with specify server arguments
+
+ This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
+ It accepts server arguments (the same as launch_server.py).
+
+ Usage:
+ python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
+
+ """
+
+ import argparse
+ import dataclasses
+ import multiprocessing
+ import os
+ import time
+
+ import requests
+
+ from sglang.srt.entrypoints.http_server import launch_server
+ from sglang.srt.managers.io_struct import GenerateReqInput
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import kill_process_tree
+ from sglang.srt.warmup import warmup
+
+ multiprocessing.set_start_method("spawn", force=True)
+
+ # Reduce warning
+ os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
+
+
+ @dataclasses.dataclass
+ class CompileArgs:
+ timeout: int = 3600
+
+ @staticmethod
+ def add_cli_args(parser: argparse.ArgumentParser):
+ parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
+
+ @classmethod
+ def from_cli_args(cls, args: argparse.Namespace):
+ # use the default value's type to cast the args into correct types.
+ attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+ return cls(
+ **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+ )
+
+
+ @warmup("compile-deep-gemm")
+ async def warm_up_compile(tokenizer_manager: TokenizerManager):
+ print("\nGenerate warm up request for compiling DeepGEMM...\n")
+ generate_req_input = GenerateReqInput(
+ input_ids=[0, 1, 2, 3],
+ sampling_params={
+ "temperature": 0.0,
+ "max_new_tokens": 8,
+ "ignore_eos": True,
+ },
+ )
+ await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
+
+
+ def launch_server_internal(server_args):
+ try:
+ launch_server(server_args)
+ except Exception as e:
+ raise e
+ finally:
+ kill_process_tree(os.getpid(), include_parent=False)
+
+
+ def launch_server_process_and_send_one_request(
+ server_args: ServerArgs, compile_args: CompileArgs
+ ):
+ proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+ proc.start()
+ base_url = f"http://{server_args.host}:{server_args.port}"
+ timeout = compile_args.timeout
+
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ try:
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ }
+ response = requests.get(f"{base_url}/v1/models", headers=headers)
+ if response.status_code == 200:
+ return proc
+ except requests.RequestException:
+ pass
+ time.sleep(10)
+ raise TimeoutError(
+ "DeepGEMM Kernels compilation timeout."
+ "\n\nFeel free and please restart the command."
+ )
+
+
+ def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
+ # Disbale cuda graph and torch compile to save time
+ server_args.disable_cuda_graph = True
+ server_args.enable_torch_compile = False
+ print(f"Disable CUDA Graph and Torch Compile to save time...")
+
+ # Set watchdog timeout to compile_args.timeout because compilation will take a long time
+ server_args.watchdog_timeout = compile_args.timeout
+ server_args.warmups = "compile-deep-gemm"
+
+
+ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
+ print(
+ "Begin DeepGEMM Kernels compilation...\n"
+ "It may take a long time and timeout maybe raised "
+ "while the compilation is still in progress.\n"
+ "Just feel free to restart the command "
+ "until the compilation is fully finished.\n"
+ )
+
+ proc = launch_server_process_and_send_one_request(server_args, compile_args)
+
+ kill_process_tree(proc.pid)
+
+ print("\nDeepGEMM Kernels compilation finished successfully.")
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ ServerArgs.add_cli_args(parser)
+ CompileArgs.add_cli_args(parser)
+ args = parser.parse_args()
+ server_args = ServerArgs.from_cli_args(args)
+ compile_args = CompileArgs.from_cli_args(args)
+
+ refine_server_args(server_args, compile_args)
+
+ run_compile(server_args, compile_args)
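The CompileArgs.from_cli_args helper above uses a small dataclass-plus-argparse idiom: each field's default value supplies the type that is used to cast the parsed command-line value back into the right Python type. A minimal self-contained sketch of the same idiom follows; the DemoArgs class and its single --timeout flag are illustrative stand-ins, not part of sglang.

    import argparse
    import dataclasses


    @dataclasses.dataclass
    class DemoArgs:
        timeout: int = 3600

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace) -> "DemoArgs":
            # Reuse each field's default-value type to cast the parsed CLI value.
            fields = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
            return cls(**{name: typ(getattr(args, name)) for name, typ in fields})


    parser = argparse.ArgumentParser()
    parser.add_argument("--timeout", type=int, default=DemoArgs.timeout)
    demo = DemoArgs.from_cli_args(parser.parse_args(["--timeout", "120"]))
    assert demo.timeout == 120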
sglang/lang/backend/anthropic.py CHANGED
@@ -1,7 +1,3 @@
- from typing import List, Optional, Union
-
- import numpy as np
-
  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.chat_template import get_chat_template
  from sglang.lang.interpreter import StreamExecutor
sglang/lang/backend/base_backend.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional, Union

  from sglang.lang.chat_template import get_chat_template
  from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
sglang/lang/backend/openai.py CHANGED
@@ -2,7 +2,7 @@ import dataclasses
  import logging
  import time
  import warnings
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional, Union

  import numpy as np

@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
  prompt = s.text_

  kwargs = sampling_params.to_openai_kwargs()
- if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+ if (
+ self.model_name.startswith("o1")
+ or self.model_name.startswith("o3")
+ or "o1" in self.model_name
+ ):
  kwargs.pop("max_tokens", None)
  else:
  kwargs.pop("max_completion_tokens", None)
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):

  def _assert_success(self, res):
  if res.status_code != 200:
- raise RuntimeError(res.json())
+ try:
+ content = res.json()
+ except json.JSONDecodeError:
+ content = res.text
+ raise RuntimeError(content)


  def compute_normalized_prompt_logprobs(input_logprobs):
sglang/lang/backend/vertexai.py CHANGED
@@ -1,6 +1,5 @@
  import os
  import warnings
- from typing import Optional

  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.chat_template import get_chat_template
sglang/lang/compiler.py CHANGED
@@ -5,13 +5,7 @@ from typing import List, Union

  from sglang.global_config import global_config
  from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
- from sglang.lang.ir import (
- SglArgument,
- SglConstantText,
- SglExpr,
- SglSamplingParams,
- SglVariable,
- )
+ from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


  def compile_func(function, backend):
sglang/lang/tracer.py CHANGED
@@ -1,20 +1,16 @@
  """Tracing a program."""

  import uuid
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

- from sglang.global_config import global_config
  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.interpreter import ProgramState, ProgramStateGroup
  from sglang.lang.ir import (
  SglArgument,
- SglCommitLazy,
- SglConcateAndAppend,
  SglConstantText,
  SglExpr,
  SglExprList,
  SglFork,
- SglFunction,
  SglGen,
  SglGetForkItem,
  SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
  self.cur_role = None

  def _execute_var_scope_end(self, expr: SglVarScopeEnd):
- new_node = SglVariable(name, source=self.last_node)
- self.variables[name] = new_node
+ new_node = SglVariable(expr.name, source=self.last_node)
+ self.variables[expr.name] = new_node

  def get_var(self, name):
  ret = self.arguments.get(name, None)
sglang/srt/_custom_ops.py CHANGED
@@ -1,10 +1,8 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
  import logging
- import os
  from typing import List, Tuple

  import torch
- import torch.library

  from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

sglang/srt/configs/model_config.py CHANGED
@@ -73,8 +73,11 @@ class ModelConfig:
  )

  if enable_multimodal is None:
- if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+ if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
  enable_multimodal = False
+ logger.info(
+ "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
+ )
  else:
  enable_multimodal = True

sglang/srt/constrained/outlines_jump_forward.py CHANGED
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
  import dataclasses
  import logging
  from collections import defaultdict
+ from typing import Optional

  import interegular
  from interegular import InvalidSyntax
- from outlines.caching import cache as disk_cache
+ from outlines.caching import cache
+
+ from sglang.srt.utils import get_bool_env_var

  try:
  # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:

  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

+ # Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+ DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
  logger = logging.getLogger(__name__)


@@ -45,6 +51,13 @@ class JumpEdge:
  byte_next_state: int = None


+ def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+ if not DISABLE_DISK_CACHE:
+ return cache(expire, typed, ignore)
+ else:
+ return lambda fn: None
+
+
  @disk_cache()
  def init_state_to_jump_forward(regex_string):
  try:
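The disk_cache wrapper above decides, once at import time, whether outlines' on-disk cache decorator is applied, based on the SGLANG_DISABLE_OUTLINES_DISK_CACHE environment variable. A minimal sketch of this conditional-decorator idiom is shown below; the maybe_cache helper, the DEMO_ENABLE_CACHE flag, and functools.lru_cache as a stand-in for the disk cache are illustrative assumptions, not sglang code.

    import os
    from functools import lru_cache


    def maybe_cache(enabled: bool):
        # Return a real caching decorator when enabled, otherwise an identity decorator.
        if enabled:
            return lru_cache(maxsize=None)
        return lambda fn: fn


    DEMO_ENABLE_CACHE = os.environ.get("DEMO_ENABLE_CACHE", "0") == "1"


    @maybe_cache(DEMO_ENABLE_CACHE)
    def compile_pattern(regex_string: str) -> str:
        # Placeholder for the expensive FSM construction that benefits from caching.
        return regex_string.upper()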
sglang/srt/constrained/triton_ops/bitmask_ops.py ADDED
@@ -0,0 +1,141 @@
+ # Adapt from
+ # https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+ from typing import List, Optional, Union
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.utils import get_device_core_count
+
+
+ @triton.jit
+ def apply_token_bitmask_inplace_kernel(
+ logits_ptr,
+ bitmask_ptr,
+ indices_ptr,
+ num_rows,
+ vocab_size,
+ logits_strides,
+ bitmask_strides,
+ NUM_SMS: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
+ where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
+ the masked logits will be set to -inf.
+
+ Parameters
+ ----------
+ logits_ptr : tl.tensor
+ Pointer to the logits tensor to apply the bitmask to.
+
+ bitmask_ptr : tl.tensor
+ Pointer to the bitmask tensor to apply.
+
+ indices_ptr : Optional[tl.tensor]
+ Optional pointer to indices tensor specifying which rows to apply the mask to.
+
+ num_rows : int
+ Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
+
+ vocab_size : int
+ Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
+ same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
+
+ logits_strides : int
+ Stride between rows in the logits tensor.
+
+ bitmask_strides : int
+ Stride between rows in the bitmask tensor.
+
+ NUM_SMS : int
+ Number of streaming multiprocessors to use.
+
+ BLOCK_SIZE : int
+ Size of processing blocks.
+ """
+
+ pid = tl.program_id(0)
+ num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+ for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+ row_id = work_id // num_blocks
+ block_offset = (work_id % num_blocks) * BLOCK_SIZE
+ batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+ offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+ bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+ vocab_mask = offsets < vocab_size
+ packed_bitmask_mask = bitmask_offsets < bitmask_strides
+ packed_bitmask = tl.load(
+ bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
+ packed_bitmask_mask,
+ )
+ bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+ bitmask = bitmask.reshape(BLOCK_SIZE)
+
+ tl.store(
+ logits_ptr + batch_id * logits_strides + offsets,
+ -float("inf"),
+ vocab_mask & bitmask,
+ )
+
+
+ def apply_token_bitmask_inplace_triton(
+ logits: torch.Tensor,
+ bitmask: torch.Tensor,
+ indices: Optional[Union[List[int], torch.Tensor]] = None,
+ ):
+ NUM_SMS = get_device_core_count()
+ BLOCK_SIZE = 4096
+ BITS_PER_BLOCK = 32
+
+ # Check input dtype
+ assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
+
+ # Check input tensor shapes.
+ logits_shape = logits.shape
+ bitmask_shape = bitmask.shape
+ if logits.ndim == 1:
+ logits_shape = (1, logits_shape[0])
+ if bitmask.ndim == 1:
+ bitmask_shape = (1, bitmask_shape[0])
+
+ required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
+ assert required_bitmask_width >= bitmask_shape[1], (
+ f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
+ f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
+ )
+
+ vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
+
+ num_rows = None
+ if isinstance(indices, list) or isinstance(indices, torch.Tensor):
+ indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+ num_rows = indices.shape[0]
+ else:
+ assert (
+ logits_shape[0] == bitmask_shape[0]
+ ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
+ num_rows = logits_shape[0]
+
+ if NUM_SMS > 0:
+ grid = (NUM_SMS,)
+ else:
+ num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+ grid = (num_rows * num_blocks,)
+ NUM_SMS = triton.next_power_of_2(grid[0])
+
+ apply_token_bitmask_inplace_kernel[grid](
+ logits,
+ bitmask,
+ indices,
+ num_rows,
+ vocab_size,
+ logits_shape[1],
+ bitmask_shape[1],
+ NUM_SMS,
+ BLOCK_SIZE,
+ num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+ num_stages=3,
+ )
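For reference, the packed-bitmask convention that apply_token_bitmask_inplace_kernel implements is: bit j of int32 word k covers token k * 32 + j, and a 0 bit forces that token's logit to -inf. The helper below reproduces those semantics in plain PyTorch; it is an illustrative sketch only, not part of the package and not performance-equivalent to the Triton kernel.

    import torch


    def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
        # logits: (batch, vocab) float tensor; bitmask: (batch, ceil(vocab / 32)) int32 tensor.
        vocab_size = logits.shape[-1]
        bit_positions = torch.arange(32, device=logits.device, dtype=torch.int32)
        # Unpack each int32 word into its 32 bits: 1 = keep the logit, 0 = mask it.
        unpacked = (bitmask.unsqueeze(-1) >> bit_positions) & 1  # (batch, words, 32)
        keep = unpacked.reshape(bitmask.shape[0], -1)[:, :vocab_size].bool()
        logits.masked_fill_(~keep, float("-inf"))

Allowing only token 0 in a row, for example, means word 0 of that row holds the value 1 and every other word is 0.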
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -25,13 +25,16 @@ from xgrammar import (
  StructuralTagItem,
  TokenizerInfo,
  allocate_token_bitmask,
- apply_token_bitmask_inplace,
  )

  from sglang.srt.constrained.base_grammar_backend import (
  BaseGrammarBackend,
  BaseGrammarObject,
  )
+ from sglang.srt.constrained.triton_ops.bitmask_ops import (
+ apply_token_bitmask_inplace_triton,
+ )
+ from sglang.srt.utils import get_bool_env_var

  logger = logging.getLogger(__name__)

@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
  self.override_stop_tokens = override_stop_tokens
  self.finished = False

+ # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+ # class init site to avoid re-initializing CUDA in forked subprocess.
+ from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+ self.use_token_bitmask_triton = get_bool_env_var(
+ "SGLANG_TOKEN_BITMASK_TRITON", "false"
+ )
+ self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+ "cuda", None
+ )
+ self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
  def accept_token(self, token: int):
  assert self.matcher.accept_token(token)

@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
  def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
  return vocab_mask.to(device, non_blocking=True)

- @staticmethod
- def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
- apply_token_bitmask_inplace(logits, vocab_mask)
+ def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+ if (
+ not self.use_token_bitmask_triton
+ and logits.device.type == "cuda"
+ and self.apply_vocab_mask_cuda
+ ):
+ return self.apply_vocab_mask_cuda(logits, vocab_mask)
+ if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+ return self.apply_vocab_mask_cpu(logits, vocab_mask)
+ apply_token_bitmask_inplace_triton(logits, vocab_mask)

  def copy(self):
  matcher = GrammarMatcher(
@@ -136,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
  def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
  try:
  if key_string == "$$ANY$$":
+ # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
  ctx = self.grammar_compiler.compile_builtin_json_grammar()
  else:
  ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
sglang/srt/custom_op.py CHANGED
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
  return self.forward_hip
  else:
  return self.forward_native
-
-
- if _is_cuda:
- from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
- def scaled_fp8_quant(
- input: torch.Tensor,
- scale: Optional[torch.Tensor] = None,
- num_token_padding: Optional[int] = None,
- use_per_token_if_dynamic: bool = False,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Quantize input tensor to FP8 (8-bit floating point) format.
-
- Args:
- input (torch.Tensor): Input tensor to be quantized
- scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
- If None, scales will be computed dynamically.
- num_token_padding (Optional[int]): If specified, pad the first dimension
- of the output to at least this value.
- use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
- determines the quantization granularity:
- - True: compute scale per token
- - False: compute single scale per tensor
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - quantized_tensor: The FP8 quantized version of input
- - scale_tensor: The scaling factors used for quantization
-
- Raises:
- AssertionError: If input is not 2D or if static scale's numel != 1
- """
- assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
- shape = input.shape
- out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
- if num_token_padding:
- shape = (max(num_token_padding, input.shape[0]), shape[1])
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
- if scale is None:
- # Dynamic scaling
- if use_per_token_if_dynamic:
- scale = torch.empty(
- (shape[0], 1), device=input.device, dtype=torch.float32
- )
- sgl_per_token_quant_fp8(input, output, scale)
- else:
- scale = torch.zeros(1, device=input.device, dtype=torch.float32)
- sgl_per_tensor_quant_fp8(
- input, output, scale, is_static=False
- ) # False for dynamic
- else:
- # Static scaling
- assert (
- scale.numel() == 1
- ), f"Expected scalar scale, got numel={scale.numel()}"
- sgl_per_tensor_quant_fp8(
- input, output, scale, is_static=True
- ) # True for static
-
- return output, scale
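The removed scaled_fp8_quant docstring above distinguishes per-tensor from per-token dynamic scaling. A toy illustration of the two granularities, using the common amax-over-FP8-max convention rather than the sgl_kernel ops, is sketched below; it is illustrative only and not the removed implementation.

    import torch

    FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

    x = torch.randn(4, 8)
    per_tensor_scale = x.abs().amax() / FP8_E4M3_MAX                    # scalar: one scale for the whole tensor
    per_token_scale = x.abs().amax(dim=1, keepdim=True) / FP8_E4M3_MAX  # shape (4, 1): one scale per token (row)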