sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  4. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
  5. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  6. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  7. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  8. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  9. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  10. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/lora/backend/__init__.py +25 -5
  19. sglang/srt/lora/backend/base_backend.py +31 -9
  20. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  21. sglang/srt/lora/backend/triton_backend.py +34 -4
  22. sglang/srt/lora/layers.py +293 -0
  23. sglang/srt/lora/lora.py +101 -326
  24. sglang/srt/lora/lora_manager.py +101 -269
  25. sglang/srt/lora/mem_pool.py +174 -0
  26. sglang/srt/lora/triton_ops/__init__.py +7 -1
  27. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  28. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  29. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  30. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  31. sglang/srt/lora/utils.py +141 -0
  32. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  33. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  34. sglang/srt/speculative/eagle_utils.py +64 -21
  35. sglang/srt/speculative/eagle_worker.py +1 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  38. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
  39. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py CHANGED
@@ -19,282 +19,25 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

 import re
-from dataclasses import dataclass
+from typing import Dict, List

 import torch
 from torch import nn

-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader


-@dataclass
-class LoraBatchInfo:
-    # Batch size
-    bs: int
-
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-
-    # Indice pointers of each sequence in shape (bs + 1, )
-    seg_indptr: torch.Tensor
-
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
-    weight_indices: torch.Tensor
-
-
-class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
-        super().__init__()
-        self.base_layer = base_layer
-        self.lora_rank = lora_rank
-        self.scaling = scaling
-        self.set_lora = False
-        self.lora_backend = lora_backend
-
-    def forward(self, x: torch.Tensor):
-        return self.base_layer.forward(x)
-
-    def set_lora_info(self, *args):
-        pass
-
-
-class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-        self.weight = base_layer.weight
-
-
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        # TODO
-        return output
-
-    def forward(self, input_: torch.Tensor):
-        # duplicate the logic in ColumnParallelLinear
-        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_, bias
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_)
-
-        if self.base_layer.gather_output:
-            output = tensor_model_parallel_all_gather(output_parallel)
-        else:
-            output = output_parallel
-        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
-        return output, output_bias
-
-
-class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def __init__(
-        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer,
-        B_buffer,
-    ):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
-
-        output_dim = base_output.shape[-1]
-        lora_output = torch.empty_like(base_output)
-        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
-            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
-            weights=self.B_buffer[0],
-        )
-
-        lora_output[:, output_dim : 2 * output_dim] = (
-            self.lora_backend.run_lora_b_sgemm(
-                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
-                weights=self.B_buffer[1],
-            )
-        )
-
-        return base_output + lora_output * self.scaling
-
-
-class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
-        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer_qkv,
-        B_buffer_q,
-        B_buffer_kv,
-    ):
-        self.set_lora = True
-        self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_qkv_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
-
-            # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
-                [
-                    0,
-                    output_dim_q,
-                    output_dim_q + output_dim_kv,
-                    output_dim_q + 2 * output_dim_kv,
-                ],
-                dtype=torch.int32,
-                device=B_buffer_q.device,
-            )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
-            self.output_offset = None
-            self.max_qkv_out_dim = None
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            output_offset=self.output_offset,
-            max_qkv_out_dim=self.max_qkv_out_dim,
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(self, A_buffer, B_buffer):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
-        lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-    def forward(self, input_):
-        # duplicate the logic in RowParallelLinear
-        if self.base_layer.input_is_parallel:
-            input_parallel = input_
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size
-            )
-            input_parallel = splitted_input[tp_rank].contiguous()
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_parallel
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_parallel)
-
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.base_layer.bias
-        return output, output_bias
-
-
-def get_lora_layer(
-    layer: nn.Module, lora_rank, scaling, lora_backend
-) -> BaseLayerWithLoRA:
-    supported_layer_types = {
-        # the order matters
-        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLoRA,
-        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
-        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
-        RowParallelLinear: RowParallelLinearWithLoRA,
-    }
-    for src_layer_type, lora_layer_type in supported_layer_types.items():
-        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
-            return ret
-    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
-
-
-def get_mapped_params(module_names):
-    ret = set()
-    for module_name in module_names:
-        ret.add(params_mapping(module_name))
-    return list(ret)
-
-
 class LoRALayer(nn.Module):
-    def __init__(self, config, base_hf_config):
+    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
-        self.config = config
-        self.base_hf_config = base_hf_config
-        self.weights = {}
-        self.weight_gpu = {}
+        self.config: LoRAConfig = config
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weight_gpu: Dict[str, torch.Tensor] = {}

     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):


 class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
+    def __init__(
+        self,
+        uid: str,
+        config: LoRAConfig,
+        base_hf_config: AutoConfig,
+        load_config: LoadConfig,
+        lora_backend: BaseLoRABackend,
+    ):
         super().__init__()
-        self.uid = uid
-        self.config = config
+        self.uid: str = uid
+        self.config: LoRAConfig = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
-        self.base_hf_config = base_hf_config
-        self.load_config = load_config
-        self.lora_backend = lora_backend
-        self.scaling = self.config.lora_alpha / self.config.r
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.load_config: LoadConfig = load_config
+        self.lora_backend: BaseLoRABackend = lora_backend
+        self.scaling: float = self.config.lora_alpha / self.config.r

-        self.layers = nn.ModuleList(
+        self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
                 for i in range(base_hf_config.num_hidden_layers)
             ]
         )

-        self.weights = {}
-        self.weights_gpu = {}
-
-    def get_stacked_multiply(self, module_name):
-        stacked_rank = {
-            "qkv_proj": 3,
-            "kv_proj": 2,
-            "gate_up_proj": 2,
-        }
-        return stacked_rank[module_name] if module_name in stacked_rank else 1
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weights_gpu: Dict[str, torch.Tensor] = {}

     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            for weight_name in weight_names:
-                if "k_proj" in weight_name:
-                    q_name = weight_name.replace("k_proj", "q_proj")
-                    v_name = weight_name.replace("k_proj", "v_proj")
-                    kv_name = weight_name.replace("k_proj", "kv_proj")
-                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[qkv_name] = torch.cat(
-                            (
-                                layer.weights[q_name],
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ),
-                            0,
-                        )
-                        layer.weights.pop(q_name)
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                    else:
-                        layer.weights[kv_name] = torch.stack(
-                            [
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ],
-                            dim=0,
-                        )
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                elif "gate_proj" in weight_name:
-                    up_name = weight_name.replace("gate_proj", "up_proj")
-                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[gate_up_name] = torch.cat(
-                            (layer.weights[weight_name], layer.weights[up_name]), 0
-                        )
-                    else:
-                        layer.weights[gate_up_name] = torch.stack(
-                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
-                        )
-                    layer.weights.pop(weight_name)
-                    layer.weights.pop(up_name)
+            self.stack_qkv_proj(weight_names, layer.weights)
+            self.stack_gate_up_proj(weight_names, layer.weights)
+
+    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        target_module = set()
+        for weight_name in weight_names:
+            if "k_proj" in weight_name:
+                target_module.add("k_proj")
+            if "q_proj" in weight_name:
+                target_module.add("q_proj")
+            if "v_proj" in weight_name:
+                target_module.add("v_proj")
+        if len(target_module) == 0:
+            return
+
+        for weight_name in weight_names:
+            # We assume every lora adaptor should contain lora modules for q_proj
+            if "q_proj" in weight_name:
+                q_name = weight_name
+                k_name = weight_name.replace("q_proj", "k_proj")
+                v_name = weight_name.replace("q_proj", "v_proj")
+                kv_name = weight_name.replace("q_proj", "kv_proj")
+                qkv_name = weight_name.replace("q_proj", "qkv_proj")
+
+                # If k_proj doesn't have lora, initialize it to zero
+                k_proj_weight = (
+                    weights[k_name]
+                    if "k_proj" in target_module
+                    else torch.zeros_like(weights[v_name])
+                )
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = torch.cat(
+                        (
+                            weights[q_name],
+                            k_proj_weight,
+                            weights[v_name],
+                        ),
+                        0,
+                    )
+                    weights.pop(q_name)
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+                else:
+                    weights[kv_name] = torch.stack(
+                        [
+                            k_proj_weight,
+                            weights[v_name],
+                        ],
+                        dim=0,
+                    )
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+
+    def stack_gate_up_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
+        for weight_name in weight_names:
+            if "gate_proj" in weight_name:
+                up_name = weight_name.replace("gate_proj", "up_proj")
+                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = torch.cat(
+                        (weights[weight_name], weights[up_name]), 0
+                    )
+                else:
+                    weights[gate_up_name] = torch.stack(
+                        [weights[weight_name], weights[up_name]], dim=0
+                    )
+                weights.pop(weight_name)
+                weights.pop(up_name)
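
The stack_qkv_proj / stack_gate_up_proj helpers added above fold per-projection LoRA weights into the fused qkv_proj and gate_up_proj modules: lora_A weights are concatenated along dim 0 (they share the same input), the k/v lora_B weights are stacked along a new leading dim, and an adapter with no k_proj LoRA gets a zero block shaped like its v_proj weight. Below is a minimal standalone sketch of that stacking on dummy tensors; the weight-name keys, tensor shapes, and the helper name demo_stack_qkv are illustrative only, not part of the package.

import torch

def demo_stack_qkv(weights: dict):
    # lora_A: q/k/v share the input dim, so they are concatenated along dim 0;
    # a missing k_proj adapter is padded with zeros shaped like v_proj,
    # mirroring the "initialize it to zero" path in stack_qkv_proj.
    k_A = weights.get("k_proj.lora_A", torch.zeros_like(weights["v_proj.lora_A"]))
    qkv_A = torch.cat((weights["q_proj.lora_A"], k_A, weights["v_proj.lora_A"]), 0)

    # lora_B: k and v are stacked along a new leading dim (q keeps its own buffer).
    k_B = weights.get("k_proj.lora_B", torch.zeros_like(weights["v_proj.lora_B"]))
    kv_B = torch.stack([k_B, weights["v_proj.lora_B"]], dim=0)
    return qkv_A, kv_B

# Dummy adapter with rank 8, hidden size 16, and no k_proj LoRA.
w = {
    "q_proj.lora_A": torch.randn(8, 16),
    "v_proj.lora_A": torch.randn(8, 16),
    "q_proj.lora_B": torch.randn(16, 8),
    "v_proj.lora_B": torch.randn(16, 8),
}
qkv_A, kv_B = demo_stack_qkv(w)
print(qkv_A.shape)  # torch.Size([24, 16]) -> q/k/v lora_A rows concatenated
print(kv_B.shape)   # torch.Size([2, 16, 8]) -> k/v lora_B stacked on dim 0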