sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  4. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
  5. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  6. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  7. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  8. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  9. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  10. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/lora/backend/__init__.py +25 -5
  19. sglang/srt/lora/backend/base_backend.py +31 -9
  20. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  21. sglang/srt/lora/backend/triton_backend.py +34 -4
  22. sglang/srt/lora/layers.py +293 -0
  23. sglang/srt/lora/lora.py +101 -326
  24. sglang/srt/lora/lora_manager.py +101 -269
  25. sglang/srt/lora/mem_pool.py +174 -0
  26. sglang/srt/lora/triton_ops/__init__.py +7 -1
  27. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  28. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  29. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  30. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  31. sglang/srt/lora/utils.py +141 -0
  32. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  33. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  34. sglang/srt/speculative/eagle_utils.py +64 -21
  35. sglang/srt/speculative/eagle_worker.py +1 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  38. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
  39. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+
+
+ @triton.jit
+ def _gate_up_lora_b_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Parameters of size
+     K,  # K = R
+     output_dim,
+     # Strides
+     x_stride_0,
+     x_stride_1,
+     w_stride_0,
+     w_stride_1,
+     w_stride_2,
+     output_stride_0,
+     output_stride_1,
+     # Information on sequence lengths and weight id
+     seg_lens,
+     seg_indptr,
+     weight_indices,
+     # Meta parameters
+     BLOCK_S: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     # For fused output scaling and adding
+     fuse_scaling_add,
+     scaling,
+ ):
+     # This kernel packs 2 sgemms (gate/up) into a single kernel.
+
+     # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
+     # weights: (num_lora, 2 * output_dim, K)
+     # output: (s, 2 * output_dim)
+     # output_dim >> K
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len.
+     # gate_up_id decides which of gate or up (0: gate, 1: up)
+     batch_id = tl.program_id(axis=2)
+     gate_up_id = tl.program_id(axis=1)
+     pid = tl.program_id(axis=0)
+     seg_len = tl.load(seg_lens + batch_id)
+     w_index = tl.load(weight_indices + batch_id)
+     seg_start = tl.load(seg_indptr + batch_id)
+     n_start = gate_up_id * output_dim  # offset on output dim
+
+     # The tile in output matrix will have (pid_s, pid_n) as id
+     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
+     pid_s = pid // num_pid_n
+     pid_n = pid % num_pid_n
+
+     # Create pointers for the first block of x and weights
+     # The pointers will be advanced as we move in the K direction
+     # and accumulate
+     s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+     k_offset = tl.arange(0, BLOCK_K)
+
+     x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
+         s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+     )
+     w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset[:, None] < seg_len)
+             and (k_offset[None, :] < K - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < K - k * BLOCK_K)
+             and (n_offset[None, :] < output_dim),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum *= scaling
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
+     if fuse_scaling_add:
+         partial_sum += tl.load(output_ptr, mask=output_mask)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def gate_up_lora_b_fwd(
+     x: torch.Tensor,
+     gate_up_lora_b: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     output_dim: int,
+     base_output: torch.Tensor = None,
+     scaling: float = 1.0,
+ ) -> torch.Tensor:
+
+     # x: (s, 2 * r)
+     # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+     # output: (s, 2 * output_dim)
+
+     # Compute lora_output with shape (s, output_dim) as follows:
+     # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
+     # lora_output[:, output_dim:]
+     #     = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
+
+     # Get dims
+     s = x.shape[0]
+     input_dim = x.shape[1]
+     r = gate_up_lora_b.shape[-1]
+     assert input_dim == 2 * r
+
+     BLOCK_S = 16
+     BLOCK_R = 16
+     BLOCK_OUT = 64
+
+     grid_b = (
+         triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
+         2,  # this dimension decides current block computes on gate or up proj
+         batch_info.bs,
+     )
+
+     if base_output is None:
+         output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
+         fuse_scaling_add = False
+     else:
+         output = base_output
+         fuse_scaling_add = True
+
+     _gate_up_lora_b_kernel[grid_b](
+         x,
+         gate_up_lora_b,
+         output,
+         r,
+         output_dim,
+         x.stride(0),
+         x.stride(1),
+         gate_up_lora_b.stride(0),
+         gate_up_lora_b.stride(1),
+         gate_up_lora_b.stride(2),
+         output.stride(0),
+         output.stride(1),
+         batch_info.seg_lens,
+         batch_info.seg_indptr,
+         batch_info.weight_indices,
+         BLOCK_S,
+         BLOCK_OUT,
+         BLOCK_R,
+         fuse_scaling_add,
+         scaling,
+     )
+
+     return output
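For orientation, here is a minimal, hypothetical call sketch of the new gate_up_lora_b_fwd entry point. The sizes (rank 16, output_dim 128, two sequences of lengths 3 and 5) and dtypes are illustrative assumptions; the LoRABatchInfo fields come from sglang/srt/lora/utils.py, which is also added in this release and shown further down.

import torch
from sglang.srt.lora.triton_ops.gate_up_lora_b import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

# Hypothetical sizes: one adapter, rank r = 16, per-projection output_dim = 128,
# two sequences of lengths 3 and 5 (s = 8 packed rows).
r, output_dim, num_lora = 16, 128, 1
batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.zeros(2, dtype=torch.int32, device="cuda"),
)

x = torch.randn(8, 2 * r, dtype=torch.float16, device="cuda")                      # (s, 2 * r)
w = torch.randn(num_lora, 2 * output_dim, r, dtype=torch.float16, device="cuda")   # (num_lora, 2 * output_dim, r)

# Returns (s, 2 * output_dim); passing base_output instead would fuse the scaled add.
out = gate_up_lora_b_fwd(x, w, batch_info, output_dim, scaling=1.0)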
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
  def qkv_lora_b_fwd(
      x: torch.Tensor,
      qkv_lora_b: torch.Tensor,
-     batch_info: LoraBatchInfo,
+     batch_info: LoRABatchInfo,
      output_offset: torch.Tensor,
      max_qkv_out_dim: int,
      base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
      # output: (s, output_dim_q + 2 * output_dim_kv)

      # Compute lora_output with shape (s, output_dim) as follows:
-     # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
+     # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
      # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-     #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+     #     = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
      # lora_output[:, output_dim_q + output_dim_kv: ]
-     #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+     #     = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])

      # Get dims
      s = x.shape[0]
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(


  def sgemm_lora_a_fwd(
-     x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+     x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
  ) -> torch.Tensor:
      # x: (s, input_dim)
      # weights: (num_lora, r, input_dim)
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
  def sgemm_lora_b_fwd(
      x: torch.Tensor,
      weights: torch.Tensor,
-     batch_info: LoraBatchInfo,
+     batch_info: LoRABatchInfo,
      base_output: torch.Tensor = None,
      scaling: float = 1.0,
  ) -> torch.Tensor:
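All three existing Triton ops above (and the new gate_up kernel) now take their batch descriptor from the renamed dataclass. The class appears to have moved from sglang.srt.lora.lora to the new sglang.srt.lora.utils module added below, so any downstream code that imported the old name will likely need this one-line change:

# Before (0.4.2.post3)
from sglang.srt.lora.lora import LoraBatchInfo
# After (0.4.2.post4)
from sglang.srt.lora.utils import LoRABatchInfo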
sglang/srt/lora/utils.py ADDED
@@ -0,0 +1,141 @@
+ import re
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Optional, Set, Tuple
+
+ import torch
+
+ from sglang.srt.hf_transformers_utils import AutoConfig
+
+
+ @dataclass
+ class LoRABatchInfo:
+     # Batch size
+     bs: int
+
+     # Lengths of each sequence in shape (bs,)
+     seg_lens: torch.Tensor
+
+     # Index pointers of each sequence in shape (bs + 1, )
+     seg_indptr: torch.Tensor
+
+     # Maximum sequence length of current batch
+     max_len: int
+
+     # The index of lora adapter used by each sequence, in shape (bs,)
+     weight_indices: torch.Tensor
+
+
+ class LoRAType(Enum):
+     LORA_A = 0
+     LORA_B = 1
+
+
+ def get_layer_id(name: str) -> int:
+     """
+     Extract integer id of layer from its name in string.
+     """
+     match = re.search(r"layers\.(\d+)\.", name)
+     if match is None:
+         return None
+     return int(match.group(1))
+
+
+ def get_customized_names_from_hf_names(
+     hf_module_names: Set[str], base_model: torch.nn.Module
+ ) -> Set[str]:
+     """
+     This function takes in a set of huggingface style module names:
+         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+     and outputs a set of module names of customized sglang layers:
+         e.g., {"qkv_proj", "o_proj"}
+     """
+     if hasattr(base_model, "get_module_name"):
+         return {base_model.get_module_name(name) for name in hf_module_names}
+     else:
+         """
+         Fallback solution of mapping from config module name to module name in model class.
+         Please check if it aligns with your base model.
+         Please implement the function in the model class if it is not.
+         You can reference this function in llama.py.
+         """
+         params_mapping = {
+             "q_proj": "qkv_proj",
+             "k_proj": "qkv_proj",
+             "v_proj": "qkv_proj",
+             "gate_proj": "gate_up_proj",
+             "up_proj": "gate_up_proj",
+         }
+         return {params_mapping.get(name, name) for name in hf_module_names}
+
+
+ def get_hidden_dim(
+     module_name: str, config: AutoConfig, base_model: torch.nn.Module
+ ) -> Tuple[int]:
+     """
+     Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
+     """
+
+     if hasattr(base_model, "get_hidden_dim"):
+         return base_model.get_hidden_dim(module_name)
+     else:
+         """
+         WARNING: get_hidden_dim() is not defined,
+         which is used to get the hidden dim for different lora modules
+         Use the default one, but please check if it is correct for your model.
+         Please implement the function in the model class if it is not.
+         You can reference this function in llama.py.
+         """
+         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+             return config.hidden_size, config.hidden_size
+         elif module_name in ["kv_proj"]:
+             return config.hidden_size, config.hidden_size // (
+                 config.num_attention_heads // config.num_key_value_heads
+             )
+         elif module_name == "gate_up_proj":
+             return config.hidden_size, config.intermediate_size
+         elif module_name == "down_proj":
+             return config.intermediate_size, config.hidden_size
+         else:
+             raise NotImplementedError()
+
+
+ def get_stacked_name(name: str) -> Tuple[str]:
+     """
+     Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+     """
+     params_mapping = {
+         "q_proj": ("qkv_proj", "q_proj"),
+         "k_proj": ("qkv_proj", "kv_proj"),
+         "v_proj": ("qkv_proj", "kv_proj"),
+         "gate_proj": ("gate_up_proj", "gate_up_proj"),
+         "up_proj": ("gate_up_proj", "gate_up_proj"),
+     }
+     return params_mapping.get(name, (name, name))
+
+
+ def get_stacked_multiply(module_name: str) -> int:
+     """
+     Mapping a lora module name to its magnification at output dimension
+     """
+     stacked_rank = {
+         "qkv_proj": 3,
+         "kv_proj": 2,
+         "gate_up_proj": 2,
+     }
+     return stacked_rank[module_name] if module_name in stacked_rank else 1
+
+
+ def get_weight_name(
+     target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+ ) -> Optional[str]:
+     """
+     target_name is name of a given module,
+     lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+     If there is a weight name in lora_weight_names that can match target_name, return this name
+     Else return None
+     """
+     idx = 0 if lora_type == LoRAType.LORA_A else 1
+     for weight_name_pair in lora_weight_names:
+         if weight_name_pair[idx] in target_name:
+             return weight_name_pair[idx]
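A short, hypothetical illustration of how these helpers behave, grounded only in the mappings defined above (the example module and parameter names are placeholders):

from sglang.srt.lora.utils import (
    LoRAType,
    get_layer_id,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

# The layer id is parsed out of a parameter name.
assert get_layer_id("model.layers.12.self_attn.q_proj.weight") == 12

# q_proj stacks into qkv_proj for LoRA A but keeps q_proj for LoRA B.
assert get_stacked_name("q_proj") == ("qkv_proj", "q_proj")

# qkv_proj triples the output dimension of a single projection.
assert get_stacked_multiply("qkv_proj") == 3

# Match a module name against the stacked-name pairs collected from an adapter.
lora_weight_names = {("qkv_proj", "q_proj"), ("gate_up_proj", "gate_up_proj")}
assert get_weight_name("layers.0.qkv_proj", lora_weight_names, LoRAType.LORA_A) == "qkv_proj"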
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -237,6 +237,7 @@ class CudaGraphRunner:
  "1. disable cuda graph by --disable-cuda-graph\n"
  "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
  "3. disable torch compile by not using --enable-torch-compile\n"
+ "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )

@@ -462,8 +463,11 @@ class CudaGraphRunner:
  ),
  positions=None,
  retrive_index=None,
+ retrive_next_token=None,
+ retrive_next_sibling=None,
  retrive_cum_len=None,
  draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+ spec_steps=self.model_runner.server_args.speculative_num_steps,
  capture_hidden_mode=CaptureHiddenMode.FULL,
  )

sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -85,6 +85,7 @@ class EAGLEDraftCudaGraphRunner:
  "1. disable cuda graph by --disable-cuda-graph\n"
  "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
  "3. disable torch compile by not using --enable-torch-compile\n"
+ "4. specify --dtype to the same dtype (e.g. bfloat16)\n"
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )

sglang/srt/speculative/eagle_utils.py CHANGED
@@ -4,6 +4,7 @@ import dataclasses
  from typing import TYPE_CHECKING, List

  import torch
+ import torch.nn.functional as F
  import triton
  import triton.language as tl

@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
  create_flashinfer_kv_indices_triton,
  )
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
- from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
+ from sglang.srt.speculative.build_eagle_tree import (
+ build_tree_kernel,
+ build_tree_kernel_efficient,
+ )
+ from sglang.srt.utils import is_cuda_available
+
+ if is_cuda_available():
+ from sgl_kernel import tree_speculative_sampling_target_only

  if TYPE_CHECKING:
  from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -160,8 +168,11 @@ class EagleVerifyInput:
  custom_mask: torch.Tensor
  positions: torch.Tensor
  retrive_index: torch.Tensor
+ retrive_next_token: torch.Tensor
+ retrive_next_sibling: torch.Tensor
  retrive_cum_len: torch.Tensor
  draft_token_num: int
+ spec_steps: int
  capture_hidden_mode: CaptureHiddenMode

  @classmethod
@@ -175,10 +186,45 @@
  seq_lens_sum: int,
  topk: int,
  spec_steps: int,
- num_verify_token: int,
+ num_verify_tokens: int,
+ is_all_greedy: bool,
  ):
- tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
- build_tree_kernel(
+ if is_all_greedy:
+ tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
+ build_tree_kernel(
+ verified_id,
+ score_list,  # b, n, topk; n= 1 + (num_steps-1) * self.topk
+ token_list,
+ parents_list,
+ seq_lens,
+ seq_lens_sum,
+ topk,
+ spec_steps,
+ num_verify_tokens,
+ )
+ )
+
+ return cls(
+ draft_tokens,
+ tree_mask,
+ position,
+ retrive_index,
+ None,
+ None,
+ retrive_cum_len,
+ num_verify_tokens,
+ spec_steps,
+ CaptureHiddenMode.FULL,
+ )
+ else:
+ (
+ tree_mask,
+ position,
+ retrive_index,
+ retrive_next_token,
+ retrive_next_sibling,
+ draft_tokens,
+ ) = build_tree_kernel_efficient(
  verified_id,
  score_list,
  token_list,
@@ -187,18 +233,21 @@
  seq_lens_sum,
  topk,
  spec_steps,
- num_verify_token,
+ num_verify_tokens,
+ )
+
+ return cls(
+ draft_tokens,
+ tree_mask,
+ position,
+ retrive_index,
+ retrive_next_token,
+ retrive_next_sibling,
+ None,
+ num_verify_tokens,
+ spec_steps,
+ CaptureHiddenMode.FULL,
  )
- )
- return cls(
- draft_tokens,
- tree_mask,
- position,
- retrive_index,
- retrive_cum_len,
- num_verify_token,
- CaptureHiddenMode.FULL,
- )

  def prepare_for_verify(self, batch: ScheduleBatch):
  batch.input_ids = self.draft_token
@@ -313,12 +362,6 @@ class EagleVerifyInput:
  uniform_samples=coins,
  target_probs=target_probs,
  draft_probs=draft_probs,
- threshold_single=global_server_args_dict[
- "speculative_accept_threshold_single"
- ],
- threshold_acc=global_server_args_dict[
- "speculative_accept_threshold_acc"
- ],
  deterministic=True,
  )

sglang/srt/speculative/eagle_worker.py CHANGED
@@ -185,6 +185,7 @@ class EAGLEWorker(TpModelWorker):
  self.topk,
  self.speculative_num_steps,
  self.server_args.speculative_num_draft_tokens,
+ batch.sampling_info.is_all_greedy,
  )

  # Free cache locations
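Taken together, the eagle_utils.py and eagle_worker.py hunks change the EagleVerifyInput.create signature: the count argument is now spelled num_verify_tokens, and a trailing is_all_greedy flag selects between build_tree_kernel (all-greedy sampling) and build_tree_kernel_efficient. A hedged sketch of the updated call site, with the leading arguments as placeholder variable names:

# Sketch of the updated EagleVerifyInput.create call (leading variable names are placeholders).
verify_input = EagleVerifyInput.create(
    verified_id,
    score_list,
    token_list,
    parents_list,
    batch.seq_lens,
    batch.seq_lens_sum,
    self.topk,
    self.speculative_num_steps,
    self.server_args.speculative_num_draft_tokens,
    batch.sampling_info.is_all_greedy,  # new argument in 0.4.2.post4
)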
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.2.post3"
+ __version__ = "0.4.2.post4"
{sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: sglang
- Version: 0.4.2.post3
+ Version: 0.4.2.post4
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
  License: Apache License
  Version 2.0, January 2004
@@ -239,11 +239,11 @@ Requires-Dist: xgrammar>=0.1.10; extra == "runtime-common"
  Provides-Extra: srt
  Requires-Dist: sglang[runtime_common]; extra == "srt"
  Requires-Dist: cuda-python; extra == "srt"
- Requires-Dist: sgl-kernel>=0.0.3.post2; extra == "srt"
+ Requires-Dist: sgl-kernel>=0.0.3.post3; extra == "srt"
  Requires-Dist: torch; extra == "srt"
- Requires-Dist: vllm==0.6.4.post1; extra == "srt"
+ Requires-Dist: vllm<=0.7.2,>=0.6.4.post1; extra == "srt"
  Requires-Dist: flashinfer_python>=0.2.0.post2; extra == "srt"
- Requires-Dist: outlines<0.1.0,>=0.0.44; extra == "srt"
+ Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
  Provides-Extra: srt-hip
  Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
  Requires-Dist: torch; extra == "srt-hip"
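As a quick check after upgrading, the installed wheel's version and the relaxed pins (sgl-kernel, vllm, outlines) can be read back from its metadata with the standard library; this is only a verification sketch and assumes sglang 0.4.2.post4 is already installed:

from importlib.metadata import requires, version

assert version("sglang") == "0.4.2.post4"
print([r for r in requires("sglang") if r.startswith(("sgl-kernel", "vllm", "outlines"))])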