sglang-0.2.15-py3-none-any.whl → sglang-0.3.1-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py (new file)
@@ -0,0 +1,403 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
+ # and "Punica: Multi-Tenant LoRA Serving"
+
+ # LoRA layers class inheritance adapted from:
+ # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
+
+
+ import json
+ import os
+ import re
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import safetensors.torch
+ import torch
+ from torch import nn
+ from vllm.model_executor.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.loader import DefaultModelLoader
+
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+
+
+ class BaseLayerWithLoRA(nn.Module):
+     def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
+         super().__init__()
+         self.base_layer = base_layer
+         self.segment_gemm = segment_gemm
+         self.lora_rank = lora_rank
+         self.scaling = scaling
+         self.set_lora = False
+
+     def forward(self, x: torch.Tensor):
+         return self.base_layer.forward(x)
+
+     def set_lora_info(self, *args):
+         pass
+
+
+ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+     def __init__(
+         self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
+     ) -> None:
+         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         self.weight = base_layer.weight
+
+
+ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
+     def __init__(
+         self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
+     ) -> None:
+         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+
+     def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         # TODO
+         return output
+
+     def forward(self, input_: torch.Tensor):
+         # duplicate the logic in ColumnParallelLinear
+         bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
+         output_parallel = self.base_layer.quant_method.apply(
+             self.base_layer, input_, bias
+         )
+
+         if self.set_lora:
+             output_parallel = self.apply_lora(output_parallel, input_)
+
+         if self.base_layer.gather_output:
+             output = tensor_model_parallel_all_gather(output_parallel)
+         else:
+             output = output_parallel
+         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
+         return output, output_bias
+
+
+ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+     def __init__(
+         self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
+     ) -> None:
+         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+
+     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+         self.set_lora = True
+         self.A_buffer = A_buffer
+         self.B_buffer = B_buffer
+         self.bs = bs
+         self.seq_lens = seq_lens
+         self.weight_indices = weight_indices
+
+     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         lora_a_output = self.segment_gemm.run(
+             x=x,
+             weights=self.A_buffer,
+             batch_size=self.bs,
+             weight_column_major=True,
+             seg_lens=self.seq_lens,
+             weight_indices=self.weight_indices,
+         )
+         # FIXME
+         assert lora_a_output.shape[-1] == self.lora_rank * 2
+         lora_output = torch.empty_like(base_output)
+         output_dim = lora_output.shape[-1] // 2
+         for i in range(2):
+             left = output_dim * i
+             right = left + output_dim
+             lora_output[:, left:right] = self.segment_gemm.run(
+                 x=lora_a_output[
+                     :, self.lora_rank * i : self.lora_rank * (i + 1)
+                 ].contiguous(),
+                 weights=self.B_buffer[:, left:right, :].contiguous(),
+                 batch_size=self.bs,
+                 weight_column_major=True,
+                 seg_lens=self.seq_lens,
+                 weight_indices=self.weight_indices,
+             )
+         return base_output + lora_output * self.scaling
+
+
+ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+     def __init__(
+         self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
+     ) -> None:
+         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+
+     def set_lora_info(
+         self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+     ):
+         self.set_lora = True
+         self.A_buffer_qkv = A_buffer_qkv
+         self.B_buffer_q = B_buffer_q
+         self.B_buffer_kv = B_buffer_kv
+         self.bs = bs
+         self.seq_lens = seq_lens
+         self.weight_indices = weight_indices
+
+     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         lora_a_output = self.segment_gemm.run(
+             x=x,
+             weights=self.A_buffer_qkv,
+             batch_size=self.bs,
+             weight_column_major=True,
+             seg_lens=self.seq_lens,
+             weight_indices=self.weight_indices,
+         )
+         # FIXME parallelize qkv
+         lora_output = torch.empty_like(base_output)
+         # q
+         output_dim_q = self.B_buffer_q.shape[-2]
+         lora_output[:, :output_dim_q] = self.segment_gemm.run(
+             x=lora_a_output[:, : self.lora_rank].contiguous(),
+             weights=self.B_buffer_q,
+             batch_size=self.bs,
+             weight_column_major=True,
+             seg_lens=self.seq_lens,
+             weight_indices=self.weight_indices,
+         )
+         # kv
+         output_dim_kv = self.B_buffer_kv.shape[-2] // 2
+         for i in range(2):
+             left = output_dim_kv * i
+             right = left + output_dim_kv
+             lora_output[:, output_dim_q + left : output_dim_q + right] = (
+                 self.segment_gemm.run(
+                     x=lora_a_output[
+                         :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
+                     ].contiguous(),
+                     weights=self.B_buffer_kv[:, left:right, :].contiguous(),
+                     batch_size=self.bs,
+                     weight_column_major=True,
+                     seg_lens=self.seq_lens,
+                     weight_indices=self.weight_indices,
+                 )
+             )
+         return base_output + lora_output * self.scaling
+
+
+ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
+     def __init__(
+         self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
+     ) -> None:
+         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+
+     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+         self.set_lora = True
+         self.A_buffer = A_buffer
+         self.B_buffer = B_buffer
+         self.bs = bs
+         self.seq_lens = seq_lens
+         self.weight_indices = weight_indices
+
+     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         lora_output = self.segment_gemm.run(
+             x=x,
+             weights=self.A_buffer,
+             batch_size=self.bs,
+             weight_column_major=True,
+             seg_lens=self.seq_lens,
+             weight_indices=self.weight_indices,
+         )
+         lora_output = self.segment_gemm.run(
+             x=lora_output,
+             weights=self.B_buffer,
+             batch_size=self.bs,
+             weight_column_major=True,
+             seg_lens=self.seq_lens,
+             weight_indices=self.weight_indices,
+         )
+         return base_output + lora_output * self.scaling
+
+     def forward(self, input_):
+         # duplicate the logic in RowParallelLinear
+         if self.base_layer.input_is_parallel:
+             input_parallel = input_
+         else:
+             tp_rank = get_tensor_model_parallel_rank()
+             splitted_input = split_tensor_along_last_dim(
+                 input_, num_partitions=self.base_layer.tp_size
+             )
+             input_parallel = splitted_input[tp_rank].contiguous()
+         output_parallel = self.base_layer.quant_method.apply(
+             self.base_layer, input_parallel
+         )
+
+         if self.set_lora:
+             output_parallel = self.apply_lora(output_parallel, input_parallel)
+
+         if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+             output_ = tensor_model_parallel_all_reduce(output_parallel)
+         else:
+             output_ = output_parallel
+
+         if not self.base_layer.skip_bias_add:
+             output = (
+                 output_ + self.base_layer.bias
+                 if self.base_layer.bias is not None
+                 else output_
+             )
+             output_bias = None
+         else:
+             output = output_
+             output_bias = self.base_layer.bias
+         return output, output_bias
+
+
+ def get_lora_layer(
+     layer: nn.Module, segment_gemm, lora_rank, scaling
+ ) -> BaseLayerWithLoRA:
+     supported_layer_types = {
+         # the order matters
+         VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
+         QKVParallelLinear: QKVParallelLinearWithLoRA,
+         MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
+         ColumnParallelLinear: ColumnParallelLinearWithLoRA,
+         RowParallelLinear: RowParallelLinearWithLoRA,
+     }
+     for src_layer_type, lora_layer_type in supported_layer_types.items():
+         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
+             ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
+             return ret
+     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
+
+
+ def get_mapped_params(module_names):
+     ret = set()
+     for module_name in module_names:
+         ret.add(params_mapping(module_name))
+     return list(ret)
+
+
+ class LoRALayer(nn.Module):
+     def __init__(self, config, base_hf_config):
+         super().__init__()
+         self.config = config
+         self.base_hf_config = base_hf_config
+         self.weights = {}
+         self.weight_gpu = {}
+
+     def load_to_gpu(self):
+         for name, weight in self.weights.items():
+             self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
+
+     def offload_from_gpu(self):
+         for name, weight in self.weights.items():
+             self.weight_gpu[name] = None
+
+
+ class LoRAAdapter(nn.Module):
+     def __init__(self, uid, config, base_hf_config, load_config):
+         super().__init__()
+         self.uid = uid
+         self.config = config
+         assert self.config.hf_config["peft_type"].lower() == "lora"
+         self.base_hf_config = base_hf_config
+         self.load_config = load_config
+         self.scaling = self.config.lora_alpha / self.config.r
+
+         self.layers = nn.ModuleList(
+             [
+                 LoRALayer(config, base_hf_config)
+                 for i in range(base_hf_config.num_hidden_layers)
+             ]
+         )
+
+         self.weights = {}
+         self.weights_gpu = {}
+
+     def get_stacked_multiply(self, module_name):
+         stacked_rank = {
+             "qkv_proj": 3,
+             "kv_proj": 2,
+             "gate_up_proj": 2,
+         }
+         return stacked_rank[module_name] if module_name in stacked_rank else 1
+
+     def load_to_gpu(self):
+         for name, weight in self.weights.items():
+             self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
+         for layer in self.layers:
+             layer.load_to_gpu()
+
+     def offload_from_gpu(self):
+         for name, weight in self.weights.items():
+             self.weights_gpu[name] = None
+         for layer in self.layers:
+             layer.offload_from_gpu()
+
+     # initialize the LoRA weights to cpu
+     def initialize_weights(self):
+         model_path = self.config.path
+         loader = DefaultModelLoader(self.load_config)
+         revision = getattr(self.config.hf_config, "revision", None)
+         for name, loaded_weight in loader._get_weights_iterator(
+             model_path, revision=revision, fall_back_to_pt=True
+         ):
+             match = re.search(r"layers\.(\d+)\.", name)
+             if match is not None:
+                 layer_id = int(match.group(1))
+                 self.layers[layer_id].weights[name] = loaded_weight.cpu()
+             else:
+                 self.weights[name] = loaded_weight.cpu()
+
+         # stack kv_proj and gate_up_proj
+         for i in range(self.base_hf_config.num_hidden_layers):
+             layer = self.layers[i]
+             weight_names = [name for name, _ in layer.weights.items()]
+             for weight_name in weight_names:
+                 if "k_proj" in weight_name:
+                     q_name = weight_name.replace("k_proj", "q_proj")
+                     v_name = weight_name.replace("k_proj", "v_proj")
+                     kv_name = weight_name.replace("k_proj", "kv_proj")
+                     qkv_name = weight_name.replace("k_proj", "qkv_proj")
+                     if "lora_A" in weight_name:
+                         layer.weights[qkv_name] = torch.cat(
+                             (
+                                 layer.weights[q_name],
+                                 layer.weights[weight_name],
+                                 layer.weights[v_name],
+                             ),
+                             0,
+                         )
+                         layer.weights.pop(q_name)
+                         layer.weights.pop(weight_name)
+                         layer.weights.pop(v_name)
+                     else:
+                         layer.weights[kv_name] = torch.cat(
+                             (
+                                 layer.weights[weight_name],
+                                 layer.weights[v_name],
+                             ),
+                             0,
+                         )
+                         layer.weights.pop(weight_name)
+                         layer.weights.pop(v_name)
+                 elif "gate_proj" in weight_name:
+                     up_name = weight_name.replace("gate_proj", "up_proj")
+                     gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                     layer.weights[gate_up_name] = torch.cat(
+                         (layer.weights[weight_name], layer.weights[up_name]), 0
+                     )
+                     layer.weights.pop(weight_name)
+                     layer.weights.pop(up_name)
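For orientation, every apply_lora path above computes the standard LoRA update, output = base_output + (x @ A^T) @ B^T * scaling, with segment_gemm selecting a per-request adapter slot via weight_indices and the stacked buffers covering the merged qkv / gate_up projections. A minimal dense-PyTorch sketch of the same arithmetic, ignoring segment batching, weight stacking, and tensor parallelism (all sizes below are made up for illustration and are not part of the package):

import torch

hidden, rank, scaling = 4096, 16, 2.0        # illustrative sizes only
x = torch.randn(8, hidden)                    # activations for 8 tokens
lora_A = torch.randn(rank, hidden)            # "lora_A" weight, shape (r, in_dim)
lora_B = torch.randn(hidden, rank)            # "lora_B" weight, shape (out_dim, r)
W_base = torch.randn(hidden, hidden)          # frozen base weight

base_output = x @ W_base.t()
lora_output = (x @ lora_A.t()) @ lora_B.t()   # what the two segment_gemm.run calls compute
output = base_output + lora_output * scaling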
sglang/srt/lora/lora_config.py (new file)
@@ -0,0 +1,43 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import json
+ import os
+
+ from huggingface_hub import snapshot_download
+
+
+ class LoRAConfig:
+     def __init__(
+         self,
+         path: str,
+     ) -> None:
+         self.path = path
+         self.hf_config = self.get_lora_config()
+         self.target_modules = self.hf_config["target_modules"]
+         self.r = self.hf_config["r"]
+         self.lora_alpha = self.hf_config["lora_alpha"]
+
+     def get_lora_config(self, dummy=False):
+         if dummy:
+             raise NotImplementedError()
+         else:
+             if not os.path.isdir(self.path):
+                 weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
+             else:
+                 weights_dir = self.path
+             config_name = "adapter_config.json"
+             with open(os.path.join(weights_dir, config_name), "r") as f:
+                 return json.load(f)
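LoRAConfig above only parses the PEFT adapter_config.json, downloading it from the Hugging Face Hub when path is not a local directory. A hedged usage sketch; the adapter id below is hypothetical:

from sglang.srt.lora.lora_config import LoRAConfig

# Hypothetical adapter: any local directory or Hub repo with an adapter_config.json.
cfg = LoRAConfig("my-org/llama-lora-example")
print(cfg.r, cfg.lora_alpha, cfg.target_modules)  # rank, alpha, e.g. ["q_proj", "v_proj"]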
sglang/srt/lora/lora_manager.py (new file)
@@ -0,0 +1,256 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
+ # and "Punica: Multi-Tenant LoRA Serving"
+
+
+ import re
+ from dataclasses import dataclass
+
+ import torch
+ from flashinfer import SegmentGEMMWrapper
+
+ from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
+ from sglang.srt.lora.lora_config import LoRAConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.utils import replace_submodule
+
+
+ def get_stacked_name(name):
+     # origin name -> (name for A, name for B)
+     params_mapping = {
+         "q_proj": ("qkv_proj", "q_proj"),
+         "k_proj": ("qkv_proj", "kv_proj"),
+         "v_proj": ("qkv_proj", "kv_proj"),
+         "gate_proj": ("gate_up_proj", "gate_up_proj"),
+         "up_proj": ("gate_up_proj", "gate_up_proj"),
+     }
+     return params_mapping.get(name, (name, name))
+
+
+ def get_layer_id(name):
+     match = re.search(r"layers\.(\d+)\.", name)
+     if match is None:
+         return None
+     return int(match.group(1))
+
+
+ class LoRAManager:
+     def __init__(
+         self,
+         base_model,
+         lora_paths,
+         base_hf_config,
+         max_loras_per_batch,
+         load_config,
+         dtype,
+     ):
+         self.base_model = base_model
+         self.lora_paths = lora_paths
+         self.base_hf_config = base_hf_config
+         self.max_loras_per_batch = max_loras_per_batch
+         self.load_config = load_config
+         self.dtype = dtype
+
+         workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
+         self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+
+         self.init_loras()
+         self.init_lora_memory_pool()
+         self.init_lora_batch()
+
+     def match_target_modules(self, module_name):
+         for target_module in self.target_modules:
+             if module_name.split(".")[-1] == target_module:
+                 return True
+         return False
+
+     def get_target_modules(self):
+         modules = []
+         for module_name, module in self.base_model.named_modules():
+             if self.match_target_modules(module_name):
+                 modules.append((module_name, module))
+         return modules
+
+     def set_lora_module(self, module_name, module):
+         lora_module = get_lora_layer(
+             module, self.segment_gemm, self.max_lora_dim, self.scaling
+         )
+         replace_submodule(self.base_model, module_name, lora_module)
+         return lora_module
+
+     def init_loras(self):
+         # get configs and target modules
+         self.configs = {}
+         self.origin_target_modules = set()
+         for path in self.lora_paths:
+             self.configs[path] = LoRAConfig(path)
+             self.origin_target_modules = set(self.origin_target_modules) | set(
+                 self.configs[path].target_modules
+             )
+         self.target_modules = set(
+             [
+                 self.base_model.get_module_name(module)
+                 for module in self.origin_target_modules
+             ]
+         )
+         self.target_weights = set(
+             [get_stacked_name(module) for module in self.origin_target_modules]
+         )
+
+         # load all weights to cpu
+         self.loras = []
+         self.lora_id = {}
+         for path in self.lora_paths:
+             self.lora_id[path] = len(self.loras)
+             self.loras.append(
+                 LoRAAdapter(
+                     path, self.configs[path], self.base_hf_config, self.load_config
+                 )
+             )
+             self.loras[-1].initialize_weights()
+
+         # misc lora configs
+         self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+         self.scaling = self.loras[0].scaling
+         # FIXME remove the restrictions
+         assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
+         assert all(x.scaling == self.scaling for x in self.loras)
+
+         # monkey patch to use the LoRA version
+         self.lora_modules = []
+         for module_name, module in self.get_target_modules():
+             self.lora_modules.append(
+                 (module_name, self.set_lora_module(module_name, module))
+             )
+
+     def init_lora_memory_pool(self):
+         # preallocate lora memory pool
+         self.A_buffer = {}
+         self.B_buffer = {}
+         num_layer = self.base_hf_config.num_hidden_layers
+         for module_A, module_B in self.target_weights:
+             # init A tensor, column_major=True
+             hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
+             c = self.loras[-1].get_stacked_multiply(module_A)
+             if module_A not in self.A_buffer:
+                 self.A_buffer[module_A] = [
+                     torch.empty(
+                         (
+                             self.max_loras_per_batch,
+                             self.max_lora_dim * c,
+                             hidden_dim_A,
+                         ),
+                         dtype=self.dtype,
+                         device="cuda",
+                     )
+                     for i in range(num_layer)
+                 ]
+             # init B tensor, column_major=True
+             _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
+             c = self.loras[-1].get_stacked_multiply(module_B)
+             if module_B not in self.B_buffer:
+                 self.B_buffer[module_B] = [
+                     torch.empty(
+                         (
+                             self.max_loras_per_batch,
+                             hidden_dim_B * c,
+                             self.max_lora_dim,
+                         ),
+                         dtype=self.dtype,
+                         device="cuda",
+                     )
+                     for i in range(num_layer)
+                 ]
+
+     def init_lora_batch(self):
+         self.active_uids = set()  # set of active loras
+         self.buffer_id = {}  # lora uid -> idx in memory pool
+
+     def get_weight_name(self, name, idx):
+         for target_weight_name in self.target_weights:
+             if target_weight_name[idx] in name:
+                 return target_weight_name[idx]
+
+     def load_lora(self, uid, buffer_id):
+         num_layer = self.base_hf_config.num_hidden_layers
+         if uid is None:
+             for i in range(num_layer):
+                 for k in self.A_buffer.keys():
+                     self.A_buffer[k][i][buffer_id] *= 0
+             return
+
+         for i in range(num_layer):
+             layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
+             for name, weights in layer_weights.items():
+                 if "lora_A" in name:
+                     lora_weight_name = self.get_weight_name(name, 0)
+                     if lora_weight_name:
+                         self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
+                 else:
+                     lora_weight_name = self.get_weight_name(name, 1)
+                     if lora_weight_name:
+                         self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
+
+     def prepare_lora_batch(self, batch, extend_seq_lens=None):
+         # load active loras into lora memory pool
+         cur_uids = set([req.lora_path for req in batch.reqs])
+         assert len(cur_uids) <= self.max_loras_per_batch
+         i = 0
+         evictable_uids = list(self.active_uids)
+         for uid in cur_uids:
+             if uid not in self.active_uids:
+                 while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                     i += 1
+                 if i < len(evictable_uids):
+                     self.active_uids.remove(evictable_uids[i])
+                     self.buffer_id.pop(evictable_uids[i])
+                 self.load_lora(uid, i)
+                 self.active_uids.add(uid)
+                 self.buffer_id[uid] = i
+                 i += 1
+
+         if cur_uids == set([None]):
+             return
+
+         # setup lora in forward modules
+         bs = len(batch.reqs)
+         seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
+         for i, req in enumerate(batch.reqs):
+             weight_indices[i] = self.buffer_id[req.lora_path]
+
+         for module_name, module in self.lora_modules:
+             layer_id = get_layer_id(module_name)
+
+             if "qkv_proj" not in module_name:
+                 weight_name = self.get_weight_name(module_name, 0)
+                 module.set_lora_info(
+                     self.A_buffer[weight_name][layer_id],
+                     self.B_buffer[weight_name][layer_id],
+                     bs,
+                     seg_lens,
+                     weight_indices,
+                 )
+             else:
+                 module.set_lora_info(
+                     self.A_buffer["qkv_proj"][layer_id],
+                     self.B_buffer["q_proj"][layer_id],
+                     self.B_buffer["kv_proj"][layer_id],
+                     bs,
+                     seg_lens,
+                     weight_indices,
+                 )
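The memory pool above reserves one buffer slot per concurrently active adapter: per layer, A buffers have shape (max_loras_per_batch, max_lora_dim * c, hidden_dim_A) and B buffers (max_loras_per_batch, hidden_dim_B * c, max_lora_dim), where c is the stacking factor (3 for qkv_proj, 2 for kv_proj and gate_up_proj). A quick shape sketch with made-up Llama-like dimensions; none of these numbers come from the diff, and in practice hidden_dim_B for kv_proj comes from base_model.get_hidden_dim and is smaller under grouped-query attention:

max_loras_per_batch, max_lora_dim = 8, 16     # illustrative values
hidden = 4096                                  # illustrative model width
stack = {"qkv_proj": 3, "kv_proj": 2, "gate_up_proj": 2}

a_qkv = (max_loras_per_batch, max_lora_dim * stack["qkv_proj"], hidden)           # (8, 48, 4096)
b_q = (max_loras_per_batch, hidden * 1, max_lora_dim)                             # (8, 4096, 16)
b_gate_up = (max_loras_per_batch, hidden * stack["gate_up_proj"], max_lora_dim)   # (8, 8192, 16)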