sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py (new file)
@@ -0,0 +1,403 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"

# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py


import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


class BaseLayerWithLoRA(nn.Module):
    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
        super().__init__()
        self.base_layer = base_layer
        self.segment_gemm = segment_gemm
        self.lora_rank = lora_rank
        self.scaling = scaling
        self.set_lora = False

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

    def set_lora_info(self, *args):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # TODO
        return output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME
        assert lora_a_output.shape[-1] == self.lora_rank * 2
        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for i in range(2):
            left = output_dim * i
            right = left + output_dim
            lora_output[:, left:right] = self.segment_gemm.run(
                x=lora_a_output[
                    :, self.lora_rank * i : self.lora_rank * (i + 1)
                ].contiguous(),
                weights=self.B_buffer[:, left:right, :].contiguous(),
                batch_size=self.bs,
                weight_column_major=True,
                seg_lens=self.seq_lens,
                weight_indices=self.weight_indices,
            )
        return base_output + lora_output * self.scaling


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q = B_buffer_q
        self.B_buffer_kv = B_buffer_kv
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer_qkv,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)
        # q
        output_dim_q = self.B_buffer_q.shape[-2]
        lora_output[:, :output_dim_q] = self.segment_gemm.run(
            x=lora_a_output[:, : self.lora_rank].contiguous(),
            weights=self.B_buffer_q,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        # kv
        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
        for i in range(2):
            left = output_dim_kv * i
            right = left + output_dim_kv
            lora_output[:, output_dim_q + left : output_dim_q + right] = (
                self.segment_gemm.run(
                    x=lora_a_output[
                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
                    ].contiguous(),
                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                    batch_size=self.bs,
                    weight_column_major=True,
                    seg_lens=self.seq_lens,
                    weight_indices=self.weight_indices,
                )
            )
        return base_output + lora_output * self.scaling


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias


def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLoRA,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")


def get_mapped_params(module_names):
    ret = set()
    for module_name in module_names:
        ret.add(params_mapping(module_name))
    return list(ret)


class LoRALayer(nn.Module):
    def __init__(self, config, base_hf_config):
        super().__init__()
        self.config = config
        self.base_hf_config = base_hf_config
        self.weights = {}
        self.weight_gpu = {}

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weight_gpu[name] = None


class LoRAAdapter(nn.Module):
    def __init__(self, uid, config, base_hf_config, load_config):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.scaling = self.config.lora_alpha / self.config.r

        self.layers = nn.ModuleList(
            [
                LoRALayer(config, base_hf_config)
                for i in range(base_hf_config.num_hidden_layers)
            ]
        )

        self.weights = {}
        self.weights_gpu = {}

    def get_stacked_multiply(self, module_name):
        stacked_rank = {
            "qkv_proj": 3,
            "kv_proj": 2,
            "gate_up_proj": 2,
        }
        return stacked_rank[module_name] if module_name in stacked_rank else 1

    def load_to_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        for name, weight in self.weights.items():
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        revision = getattr(self.config.hf_config, "revision", None)
        for name, loaded_weight in loader._get_weights_iterator(
            model_path, revision=revision, fall_back_to_pt=True
        ):
            match = re.search(r"layers\.(\d+)\.", name)
            if match is not None:
                layer_id = int(match.group(1))
                self.layers[layer_id].weights[name] = loaded_weight.cpu()
            else:
                self.weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for i in range(self.base_hf_config.num_hidden_layers):
            layer = self.layers[i]
            weight_names = [name for name, _ in layer.weights.items()]
            for weight_name in weight_names:
                if "k_proj" in weight_name:
                    q_name = weight_name.replace("k_proj", "q_proj")
                    v_name = weight_name.replace("k_proj", "v_proj")
                    kv_name = weight_name.replace("k_proj", "kv_proj")
                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
                    if "lora_A" in weight_name:
                        layer.weights[qkv_name] = torch.cat(
                            (
                                layer.weights[q_name],
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(q_name)
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                    else:
                        layer.weights[kv_name] = torch.cat(
                            (
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                        )
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                    layer.weights[gate_up_name] = torch.cat(
                        (layer.weights[weight_name], layer.weights[up_name]), 0
                    )
                    layer.weights.pop(weight_name)
                    layer.weights.pop(up_name)
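
Note (not part of the diff): apply_lora above computes the standard LoRA update y = base(x) + scaling * B(A(x)), batched over per-request adapter slots via flashinfer's segment GEMM. A minimal plain-PyTorch sketch of the same math for a single unstacked linear layer (hypothetical sizes, one adapter, no segment GEMM or tensor parallelism):

import torch

# Hypothetical sizes for illustration only.
hidden, rank, scaling = 64, 8, 2.0
base = torch.nn.Linear(hidden, hidden, bias=False)
lora_A = torch.randn(rank, hidden) * 0.01   # maps hidden -> rank
lora_B = torch.zeros(hidden, rank)          # maps rank -> hidden, zero-initialized

x = torch.randn(4, hidden)                  # 4 tokens
base_output = base(x)
# Same two-step product the segment GEMM performs per adapter slot:
lora_output = (x @ lora_A.t()) @ lora_B.t()
y = base_output + lora_output * scaling     # matches apply_lora's return value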
sglang/srt/lora/lora_config.py (new file)
@@ -0,0 +1,43 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import os

from huggingface_hub import snapshot_download


class LoRAConfig:
    def __init__(
        self,
        path: str,
    ) -> None:
        self.path = path
        self.hf_config = self.get_lora_config()
        self.target_modules = self.hf_config["target_modules"]
        self.r = self.hf_config["r"]
        self.lora_alpha = self.hf_config["lora_alpha"]

    def get_lora_config(self, dummy=False):
        if dummy:
            raise NotImplementedError()
        else:
            if not os.path.isdir(self.path):
                weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
            else:
                weights_dir = self.path
            config_name = "adapter_config.json"
            with open(os.path.join(weights_dir, config_name), "r") as f:
                return json.load(f)
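
A usage sketch for LoRAConfig (not part of the diff; the adapter path is a placeholder for any local directory or Hugging Face repo id that contains a PEFT-style adapter_config.json):

from sglang.srt.lora.lora_config import LoRAConfig

# "path/to/adapter" is a placeholder for a real adapter location.
config = LoRAConfig("path/to/adapter")
print(config.r, config.lora_alpha, config.target_modules)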
sglang/srt/lora/lora_manager.py (new file)
@@ -0,0 +1,259 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"


import re
from dataclasses import dataclass

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_hip, replace_submodule

# ROCm: flashinfer available later
if not is_hip():
    from flashinfer import SegmentGEMMWrapper


def get_stacked_name(name):
    # origin name -> (name for A, name for B)
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


class LoRAManager:
    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        for target_module in self.target_modules:
            if module_name.split(".")[-1] == target_module:
                return True
        return False

    def get_target_modules(self):
        modules = []
        for module_name, module in self.base_model.named_modules():
            if self.match_target_modules(module_name):
                modules.append((module_name, module))
        return modules

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for name, path in self.lora_paths.items():
            self.configs[name] = LoRAConfig(path)
            self.origin_target_modules = set(self.origin_target_modules) | set(
                self.configs[name].target_modules
            )
        self.target_modules = set(
            [
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            ]
        )
        self.target_weights = set(
            [get_stacked_name(module) for module in self.origin_target_modules]
        )

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for name in self.lora_paths.keys():
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    name, self.configs[name], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # misc lora configs
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

    def init_lora_batch(self):
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]

    def load_lora(self, uid, buffer_id):
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            for i in range(num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, batch, extend_seq_lens=None):
        # load active loras into lora memory pool
        cur_uids = set([req.lora_path for req in batch.reqs])
        assert len(cur_uids) <= self.max_loras_per_batch
        i = 0
        evictable_uids = list(self.active_uids)
        for uid in cur_uids:
            if uid not in self.active_uids:
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                if i < len(evictable_uids):
                    self.active_uids.remove(evictable_uids[i])
                    self.buffer_id.pop(evictable_uids[i])
                self.load_lora(uid, i)
                self.active_uids.add(uid)
                self.buffer_id[uid] = i
                i += 1

        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        bs = len(batch.reqs)
        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, req in enumerate(batch.reqs):
            weight_indices[i] = self.buffer_id[req.lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
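
The slot assignment in prepare_lora_batch is the subtle part: adapters already resident in the memory pool keep their slots, while slots whose adapters are not needed by the current batch are evicted and reused. A standalone sketch of just that bookkeeping (illustrative names, not part of the diff; the load_lora call is reduced to a comment):

def assign_slots(cur_uids, active_uids, buffer_id, max_slots):
    """Mirror prepare_lora_batch's eviction loop: reuse slots of adapters
    that are resident but not requested by the current batch."""
    assert len(cur_uids) <= max_slots
    i = 0
    evictable = list(active_uids)
    for uid in cur_uids:
        if uid not in active_uids:
            # skip slots whose adapter is still needed by this batch
            while i < len(evictable) and evictable[i] in cur_uids:
                i += 1
            if i < len(evictable):
                active_uids.remove(evictable[i])
                buffer_id.pop(evictable[i])
            # load_lora(uid, i) would copy the adapter's weights into slot i here
            active_uids.add(uid)
            buffer_id[uid] = i
            i += 1
    return buffer_id

# Example: pool of 2 slots, adapter "a" already resident in slot 0.
slots = assign_slots({"a", "b"}, {"a"}, {"a": 0}, max_slots=2)
print(slots)  # {'a': 0, 'b': 1} -- "b" takes the next free slot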