sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/function_call_parser.py +96 -69
  4. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  5. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  6. sglang/srt/layers/attention/triton_backend.py +64 -16
  7. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
  10. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/fp8_kernel.py +43 -10
  24. sglang/srt/lora/backend/__init__.py +25 -5
  25. sglang/srt/lora/backend/base_backend.py +31 -9
  26. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  27. sglang/srt/lora/backend/triton_backend.py +34 -4
  28. sglang/srt/lora/layers.py +293 -0
  29. sglang/srt/lora/lora.py +101 -326
  30. sglang/srt/lora/lora_manager.py +101 -269
  31. sglang/srt/lora/mem_pool.py +174 -0
  32. sglang/srt/lora/triton_ops/__init__.py +7 -1
  33. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  34. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  35. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  36. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  37. sglang/srt/lora/utils.py +141 -0
  38. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  39. sglang/srt/models/llama.py +8 -3
  40. sglang/srt/speculative/build_eagle_tree.py +482 -102
  41. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  42. sglang/srt/speculative/eagle_utils.py +134 -61
  43. sglang/srt/speculative/eagle_worker.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  46. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
  47. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  49. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py (new file)
@@ -0,0 +1,293 @@
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.lora.backend import BaseLoRABackend
+
+
+class BaseLayerWithLoRA(nn.Module):
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ):
+        super().__init__()
+        self.base_layer: nn.Module = base_layer
+        self.lora_rank: int = lora_rank
+        self.scaling: float = scaling
+        self.set_lora: bool = False
+        self.lora_backend: BaseLoRABackend = lora_backend
+
+    def forward(self, x: torch.Tensor):
+        return self.base_layer.forward(x)
+
+    def set_lora_info(self, *args):
+        pass
+
+
+class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    def __init__(
+        self,
+        base_layer: VocabParallelEmbedding,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ) -> None:
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        self.weight = base_layer.weight
+
+
+class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
+    def __init__(
+        self,
+        base_layer: ColumnParallelLinear,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ) -> None:
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+
+    def set_lora_info(
+        self,
+        A_buffer: torch.Tensor,
+        B_buffer: torch.Tensor,
+    ):
+        self.set_lora = True
+        self.A_buffer = A_buffer
+        self.B_buffer = B_buffer
+
+    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
+        lora_output = self.lora_backend.run_lora_b_sgemm(
+            lora_a_output,
+            self.B_buffer[0],
+            **backend_kwargs,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )
+
+    def forward(self, input_: torch.Tensor):
+        # duplicate the logic in ColumnParallelLinear
+        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
+        output_parallel = self.base_layer.quant_method.apply(
+            self.base_layer, input_, bias
+        )
+
+        if self.set_lora:
+            output_parallel = self.apply_lora(output_parallel, input_)
+
+        if self.base_layer.gather_output:
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
+        return output, output_bias
+
+
+class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+    def __init__(
+        self,
+        base_layer: MergedColumnParallelLinear,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ) -> None:
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+
+    def set_lora_info(
+        self,
+        A_buffer: torch.Tensor,
+        B_buffer: torch.Tensor,
+    ):
+        self.set_lora = True
+        self.A_buffer_gate_up = A_buffer
+        if self.lora_backend.fuse_stacked_lora_b:
+            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
+            self.B_buffer_gate_up = torch.cat(
+                (B_buffer[0], B_buffer[1]), dim=-2
+            ).contiguous()
+        else:
+            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
+
+    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+
+        lora_output = self.lora_backend.run_gate_up_lora(
+            x,
+            self.A_buffer_gate_up,
+            self.B_buffer_gate_up,
+            **backend_kwargs,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )
+
+
+class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+    def __init__(
+        self,
+        base_layer: QKVParallelLinear,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ) -> None:
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+
+    def set_lora_info(
+        self,
+        A_buffer_qkv: torch.Tensor,
+        B_buffer_q: torch.Tensor,
+        B_buffer_kv: torch.Tensor,
+    ):
+        self.set_lora = True
+        self.A_buffer_qkv = A_buffer_qkv
+
+        if self.lora_backend.fuse_stacked_lora_b:
+            assert (
+                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
+
+            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+            self.B_buffer_qkv = torch.cat(
+                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
+            ).contiguous()
+
+            # Offsets of q/k/v in the output dimension
+            self.output_offset = torch.tensor(
+                [
+                    0,
+                    output_dim_q,
+                    output_dim_q + output_dim_kv,
+                    output_dim_q + 2 * output_dim_kv,
+                ],
+                dtype=torch.int32,
+                device=B_buffer_q.device,
+            )
+            # For computing the number of launched blocks
+            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+        else:
+            self.B_buffer_qkv = (
+                B_buffer_q,
+                B_buffer_kv,
+            )
+
+    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        if self.lora_backend.fuse_stacked_lora_b:
+            backend_kwargs["output_offset"] = self.output_offset
+            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
+
+        lora_output = self.lora_backend.run_qkv_lora(
+            x,
+            self.A_buffer_qkv,
+            self.B_buffer_qkv,
+            **backend_kwargs,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )
+
+
+class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
+    def __init__(
+        self,
+        base_layer: RowParallelLinear,
+        lora_rank: int,
+        scaling: float,
+        lora_backend: BaseLoRABackend,
+    ) -> None:
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+
+    def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
+        self.set_lora = True
+        self.A_buffer = A_buffer
+        self.B_buffer = B_buffer
+
+    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
+        lora_output = self.lora_backend.run_lora_b_sgemm(
+            lora_a_output,
+            self.B_buffer[0],
+            **backend_kwargs,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )
+
+    def forward(self, input_: torch.Tensor):
+        # duplicate the logic in RowParallelLinear
+        if self.base_layer.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.base_layer.tp_size
+            )
+            input_parallel = splitted_input[tp_rank].contiguous()
+        output_parallel = self.base_layer.quant_method.apply(
+            self.base_layer, input_parallel
+        )
+
+        if self.set_lora:
+            output_parallel = self.apply_lora(output_parallel, input_parallel)
+
+        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+            output_ = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output_ = output_parallel
+
+        if not self.base_layer.skip_bias_add:
+            output = (
+                output_ + self.base_layer.bias
+                if self.base_layer.bias is not None
+                else output_
+            )
+            output_bias = None
+        else:
+            output = output_
+            output_bias = self.base_layer.bias
+        return output, output_bias
+
+
+def get_lora_layer(
+    layer: nn.Module, lora_rank: int, scaling: float, lora_backend: BaseLoRABackend
+) -> BaseLayerWithLoRA:
+    supported_layer_types = {
+        # the order matters
+        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
+        QKVParallelLinear: QKVParallelLinearWithLoRA,
+        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
+        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
+        RowParallelLinear: RowParallelLinearWithLoRA,
+    }
+    for src_layer_type, lora_layer_type in supported_layer_types.items():
+        if isinstance(layer, src_layer_type):
+            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
+            return ret
+    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
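
For reference, when a backend reports fuse_output_scaling_add = False, each apply_lora above reduces to the standard LoRA update: project the input through the A matrix, then through the B matrix, scale the result, and add it to the base GEMM output. The snippet below is a minimal, self-contained sketch of that math using plain torch matmuls; the toy_apply_lora helper, the single-adapter tensor shapes, and the example dimensions are illustrative assumptions, not the buffer layout or backend API used by sglang.

import torch


def toy_apply_lora(
    base_output: torch.Tensor,  # (num_tokens, output_dim): result of the base layer GEMM
    x: torch.Tensor,            # (num_tokens, input_dim)
    lora_a: torch.Tensor,       # (rank, input_dim)
    lora_b: torch.Tensor,       # (output_dim, rank)
    scaling: float,
) -> torch.Tensor:
    # Unfused path: an A projection, then a B projection, then
    # base_output + lora_output * scaling (as in apply_lora above).
    lora_a_output = x @ lora_a.T            # (num_tokens, rank)
    lora_output = lora_a_output @ lora_b.T  # (num_tokens, output_dim)
    return base_output + lora_output * scaling


if __name__ == "__main__":
    num_tokens, input_dim, output_dim, rank = 4, 16, 32, 8
    x = torch.randn(num_tokens, input_dim)
    weight = torch.randn(output_dim, input_dim)  # base layer weight
    lora_a = torch.randn(rank, input_dim)
    lora_b = torch.randn(output_dim, rank)
    base_output = x @ weight.T
    out = toy_apply_lora(base_output, x, lora_a, lora_b, scaling=0.5)
    print(out.shape)  # torch.Size([4, 32])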