sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (50)
  1. sglang/bench_serving.py +113 -3
  2. sglang/srt/configs/model_config.py +5 -2
  3. sglang/srt/constrained/__init__.py +2 -66
  4. sglang/srt/constrained/base_grammar_backend.py +72 -0
  5. sglang/srt/constrained/outlines_backend.py +165 -0
  6. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  7. sglang/srt/constrained/xgrammar_backend.py +114 -0
  8. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  10. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  11. sglang/srt/layers/quantization/base_config.py +4 -6
  12. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  13. sglang/srt/managers/io_struct.py +5 -3
  14. sglang/srt/managers/schedule_batch.py +14 -20
  15. sglang/srt/managers/scheduler.py +153 -94
  16. sglang/srt/managers/tokenizer_manager.py +81 -17
  17. sglang/srt/metrics/collector.py +211 -0
  18. sglang/srt/metrics/func_timer.py +108 -0
  19. sglang/srt/mm_utils.py +1 -1
  20. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  21. sglang/srt/model_executor/forward_batch_info.py +7 -3
  22. sglang/srt/model_executor/model_runner.py +2 -1
  23. sglang/srt/models/gemma2_reward.py +69 -0
  24. sglang/srt/models/gpt2.py +31 -37
  25. sglang/srt/models/internlm2_reward.py +62 -0
  26. sglang/srt/models/llama.py +11 -6
  27. sglang/srt/models/llama_reward.py +5 -26
  28. sglang/srt/models/qwen2_vl.py +5 -7
  29. sglang/srt/openai_api/adapter.py +6 -2
  30. sglang/srt/sampling/sampling_batch_info.py +2 -3
  31. sglang/srt/sampling/sampling_params.py +0 -14
  32. sglang/srt/server.py +58 -16
  33. sglang/srt/server_args.py +42 -22
  34. sglang/srt/utils.py +87 -0
  35. sglang/test/simple_eval_common.py +1 -1
  36. sglang/test/simple_eval_humaneval.py +2 -2
  37. sglang/test/simple_eval_mgsm.py +2 -2
  38. sglang/test/test_utils.py +18 -4
  39. sglang/utils.py +1 -0
  40. sglang/version.py +1 -1
  41. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
  42. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
  43. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  44. sglang/srt/constrained/base_tool_cache.py +0 -65
  45. sglang/srt/constrained/bnf_cache.py +0 -61
  46. sglang/srt/constrained/fsm_cache.py +0 -95
  47. sglang/srt/constrained/grammar.py +0 -190
  48. sglang/srt/constrained/jump_forward.py +0 -203
  49. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  50. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/xgrammar_backend.py
@@ -0,0 +1,114 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Constrained decoding with xgrammar backend."""
+
+from typing import List, Tuple
+
+import torch
+from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+
+MAX_ROLLBACK_TOKENS = 10
+
+
+class XGrammarGrammar(BaseGrammarObject):
+
+    def __init__(
+        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
+    ) -> None:
+        self.matcher = matcher
+        self.vocab_size = vocab_size
+        self.ctx = ctx
+
+    def accept_token(self, token: int):
+        assert self.matcher.accept_token(token)
+
+    def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, data = helper
+        return data, -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        k = 0
+        for i, old_id in enumerate(old_output_ids):
+            if old_id == new_output_ids[i]:
+                k = i + 1
+            else:
+                break
+
+        # rollback to the last token that is the same
+        if k < len(old_output_ids):
+            self.matcher.rollback(len(old_output_ids) - k)
+
+        for i in range(k, len(new_output_ids)):
+            assert self.matcher.accept_token(new_output_ids[i])
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+        # Note that this bitmask is a bitset, not bool
+        bitmask = self.matcher.get_next_token_bitmask()
+        # Mask the tokens that are not allowed
+        vocab_mask[
+            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
+        ] = 1
+
+    def copy(self):
+        matcher = GrammarMatcher(
+            self.ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            mask_vocab_size=self.vocab_size,
+        )
+        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
+
+
+class XGrammarGrammarBackend(BaseGrammarBackend):
+    def __init__(
+        self,
+        tokenizer,
+        vocab_size: int,
+    ):
+        super().__init__()
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        self.vocab_size = vocab_size
+
+    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
+        key_type, key_string = key
+        if key_type == "json":
+            ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
+        elif key_type == "regex":
+            raise ValueError("regex hasn't been supported by xgrammar yet")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        matcher = GrammarMatcher(
+            ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            mask_vocab_size=self.vocab_size,
+        )
+        return XGrammarGrammar(matcher, self.vocab_size, ctx)
+
+    def reset(self):
+        self.grammar_cache.clear()
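For orientation, here is a minimal sketch of how the new backend could be exercised on its own, outside the scheduler that normally drives it. The model name, the JSON schema, and the use of len(tokenizer) as the vocabulary size are illustrative assumptions, and calling init_value_impl directly is only for demonstration; this is not how sglang wires the backend internally.

import torch
from transformers import AutoTokenizer

from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

# Hypothetical tokenizer; any HF tokenizer compatible with xgrammar would do.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
backend = XGrammarGrammarBackend(tokenizer, vocab_size=len(tokenizer))

# Compile a grammar for a JSON schema. Only "json" keys are supported in this
# release; a "regex" key raises ValueError.
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
grammar = backend.init_value_impl(("json", schema))

# At each decoding step, fill_vocab_mask marks tokens the grammar rejects, and
# accept_token then advances the matcher with whatever token was sampled.
vocab_mask = torch.zeros(len(tokenizer), dtype=torch.bool)
grammar.fill_vocab_mask(vocab_mask)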
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
 
     num_warps = 4
 
+    extra_kargs = {}
+    if is_hip():
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
+        **extra_kargs,
     )
 
 
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -25,6 +25,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.utils import is_hip
 
 is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
@@ -311,6 +312,10 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    extra_kargs = {}
+    if is_hip():
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -348,6 +353,7 @@ def extend_attention_fwd(
         Lv=Lv,
         num_warps=num_warps,
         num_stages=num_stages,
+        **extra_kargs,
     )
 
 
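Both attention kernels apply the same pattern: ROCm-specific Triton compiler options are collected in a dict and splatted into the kernel launch only when is_hip() is true, so the CUDA launch is untouched. Below is a self-contained sketch of that pattern around a toy copy kernel; the kernel itself is not part of sglang, and the tuning values simply mirror the ones used above.

import torch
import triton
import triton.language as tl

from sglang.srt.utils import is_hip


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    extra_kargs = {}
    if is_hip():
        # AMD-only compiler hints; on CUDA the dict stays empty and nothing changes.
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    grid = (triton.cdiv(src.numel(), 1024),)
    _copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=1024, **extra_kargs)
    return dst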
sglang/srt/layers/fused_moe/fused_moe.py
@@ -54,6 +54,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padding_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_ks = True
+    else:
+        even_ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
         top_k=top_k,
         compute_type=compute_type,
        use_fp8=use_fp8,
+        even_Ks=even_ks,
         **config,
     )
 
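The even_Ks flag is a small fast path: when the reduction dimension K is an exact multiple of BLOCK_SIZE_K, no tile ever reads past the end of K, so the kernel can drop the offs_k boundary masks on the A and B loads (and load B with no mask at all). The host-side check could equally be written as a single expression, as in this sketch; the names mirror the diff, but this is not the sglang code itself.

def choose_even_ks(K: int, block_size_k: int) -> bool:
    # Unmasked K-dimension loads are safe only when K splits into whole tiles.
    return K % block_size_k == 0


assert choose_even_ks(4096, 128) is True   # take the unmasked fast path
assert choose_even_ks(4100, 128) is False  # keep the boundary masks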
sglang/srt/layers/quantization/base_config.py
@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
         """
         raise NotImplementedError
 
-def method_has_implemented_embedding(
-        method_class: Type[QuantizeMethodBase]) -> bool:
+
+def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
     """
     Not all quant methods have embedding implemented, so we need to check that
     it exists for our given method. We check this by making sure the function
     has been changed from the base implementation.
     """
-    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
-                                            None)
+    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
     class_embedding = inspect.getattr_static(method_class, "embedding", None)
 
-    return (class_embedding is not None
-            and class_embedding is not base_embedding)
+    return class_embedding is not None and class_embedding is not base_embedding
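For reference, a self-contained sketch of what the reformatted helper checks: it compares the statically resolved embedding attribute of a method class against the base class, so it reports True only when the subclass actually overrides it. The classes below are toy stand-ins written for this example, not sglang's real quantization methods.

import inspect
from typing import Type


class QuantizeMethodBase:
    def embedding(self, layer, x):
        raise NotImplementedError


def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)
    return class_embedding is not None and class_embedding is not base_embedding


class ToyEmbeddingMethod(QuantizeMethodBase):
    def embedding(self, layer, x):
        return x  # pretend-dequantized lookup


assert method_has_implemented_embedding(ToyEmbeddingMethod) is True
assert method_has_implemented_embedding(QuantizeMethodBase) is False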