sglang 0.4.2.post2-py3-none-any.whl → 0.4.2.post4-py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (49)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/function_call_parser.py +96 -69
  4. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  5. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  6. sglang/srt/layers/attention/triton_backend.py +64 -16
  7. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
  10. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/fp8_kernel.py +43 -10
  24. sglang/srt/lora/backend/__init__.py +25 -5
  25. sglang/srt/lora/backend/base_backend.py +31 -9
  26. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  27. sglang/srt/lora/backend/triton_backend.py +34 -4
  28. sglang/srt/lora/layers.py +293 -0
  29. sglang/srt/lora/lora.py +101 -326
  30. sglang/srt/lora/lora_manager.py +101 -269
  31. sglang/srt/lora/mem_pool.py +174 -0
  32. sglang/srt/lora/triton_ops/__init__.py +7 -1
  33. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  34. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  35. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  36. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  37. sglang/srt/lora/utils.py +141 -0
  38. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  39. sglang/srt/models/llama.py +8 -3
  40. sglang/srt/speculative/build_eagle_tree.py +482 -102
  41. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  42. sglang/srt/speculative/eagle_utils.py +134 -61
  43. sglang/srt/speculative/eagle_worker.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  46. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
  47. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  49. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/lora/lora_manager.py
+++ b/sglang/srt/lora/lora_manager.py
@@ -16,307 +16,115 @@
  # and "Punica: Multi-Tenant LoRA Serving"

  import logging
- import re
+ from typing import Dict, List, Set, Tuple

  import torch

- from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
- from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
+ from sglang.srt.configs.load_config import LoadConfig
+ from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+ from sglang.srt.lora.layers import get_lora_layer
+ from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
+ from sglang.srt.lora.mem_pool import LoRAMemoryPool
+ from sglang.srt.lora.utils import (
+     LoRABatchInfo,
+     LoRAType,
+     get_customized_names_from_hf_names,
+     get_layer_id,
+     get_stacked_name,
+     get_weight_name,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import is_flashinfer_available, replace_submodule
+ from sglang.srt.utils import replace_submodule

  logger = logging.getLogger(__name__)


- def get_module_name(name):
-     # Fallback solution of mapping from config module name to module name in model class.
-     # Please check if it aligns with your base model.
-     # Please implement the function in the model class if it is not.
-     # You can reference this function in llama.py.
-     params_mapping = {
-         "q_proj": "qkv_proj",
-         "k_proj": "qkv_proj",
-         "v_proj": "qkv_proj",
-         "gate_proj": "gate_up_proj",
-         "up_proj": "gate_up_proj",
-     }
-     return params_mapping.get(name, name)
-
-
- def get_hidden_dim(module_name, config):
-     # Fallback solution of get_hidden_dim for different modules
-     # Please check if it aligns with your base model.
-     # Please implement the function in the model class if it is not.
-     # You can reference this function in llama.py.
-     if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-         return config.hidden_size, config.hidden_size
-     elif module_name in ["kv_proj"]:
-         return config.hidden_size, config.hidden_size // (
-             config.num_attention_heads // config.num_key_value_heads
-         )
-     elif module_name == "gate_up_proj":
-         return config.hidden_size, config.intermediate_size
-     elif module_name == "down_proj":
-         return config.intermediate_size, config.hidden_size
-     else:
-         raise NotImplementedError()
-
-
- def get_stacked_name(name):
-     # origin name -> (name for A, name for B)
-     params_mapping = {
-         "q_proj": ("qkv_proj", "q_proj"),
-         "k_proj": ("qkv_proj", "kv_proj"),
-         "v_proj": ("qkv_proj", "kv_proj"),
-         "gate_proj": ("gate_up_proj", "gate_up_proj"),
-         "up_proj": ("gate_up_proj", "gate_up_proj"),
-     }
-     return params_mapping.get(name, (name, name))
-
-
- def get_backend_from_name(name):
-     backend_mapping = {
-         "triton": TritonLoraBackend,
-         "flashinfer": FlashInferLoraBackend,
-     }
-
-     if name in backend_mapping:
-         return backend_mapping[name]
-
-     raise Exception(
-         f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
-     )
-
-
- def get_layer_id(name):
-     match = re.search(r"layers\.(\d+)\.", name)
-     if match is None:
-         return None
-     return int(match.group(1))
-
-
  class LoRAManager:
      def __init__(
          self,
-         base_model,
-         lora_paths,
-         base_hf_config,
-         max_loras_per_batch,
-         load_config,
-         dtype,
-         lora_backend,
+         base_model: torch.nn.Module,
+         lora_paths: Dict[str, str],
+         base_hf_config: AutoConfig,
+         max_loras_per_batch: int,
+         load_config: LoadConfig,
+         dtype: torch.dtype,
+         lora_backend: str = "triton",
      ):
-         self.base_model = base_model
-         self.lora_paths = lora_paths
-         self.base_hf_config = base_hf_config
-         self.max_loras_per_batch = max_loras_per_batch
-         self.load_config = load_config
-         self.dtype = dtype
-
-         logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+         self.base_model: torch.nn.Module = base_model
+         self.lora_paths: Dict[str, str] = lora_paths
+         self.base_hf_config: AutoConfig = base_hf_config
+         self.max_loras_per_batch: int = max_loras_per_batch
+         self.load_config: LoadConfig = load_config
+         self.dtype: torch.dtype = dtype
+
+         # LoRA backend for running sgemm kernels
+         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
          backend_type = get_backend_from_name(lora_backend)
-         self.lora_backend = backend_type(lora_backend)
+         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

          self.init_loras()
          self.init_lora_memory_pool()
-         self.init_lora_batch()
-
-     def match_target_modules(self, module_name):
-         for target_module in self.target_modules:
-             if module_name.split(".")[-1] == target_module:
-                 return True
-         return False
-
-     def get_target_modules(self):
-         modules = []
-         for module_name, module in self.base_model.named_modules():
-             if self.match_target_modules(module_name):
-                 modules.append((module_name, module))
-         return modules
-
-     def set_lora_module(self, module_name, module):
-         lora_module = get_lora_layer(
-             module, self.max_lora_dim, self.scaling, self.lora_backend
-         )
-         replace_submodule(self.base_model, module_name, lora_module)
-         return lora_module

      def init_loras(self):
-         # get configs and target modules
-         self.configs = {}
-         self.origin_target_modules = set()
+         # Config of each LoRA adapter
+         self.configs: Dict[str, LoRAConfig] = {}
+
+         # Target module names in huggingface lora configs.
+         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+         self.hf_target_names: Set[str] = set()
          for name, path in self.lora_paths.items():
              self.configs[name] = LoRAConfig(path)
-             self.origin_target_modules = set(self.origin_target_modules) | set(
+             self.hf_target_names = set(self.hf_target_names) | set(
                  self.configs[name].target_modules
              )
-         if hasattr(self.base_model, "get_module_name"):
-             self.target_modules = {
-                 self.base_model.get_module_name(module)
-                 for module in self.origin_target_modules
-             }
-         else:
-             logger.warning(
-                 "WARNING: get_module_name() is not defined, "
-                 "which is used to map config module name to model implementation module name."
-                 "Use the default one, but please check if it is correct for your model."
-             )
-             self.target_modules = {
-                 get_module_name(module) for module in self.origin_target_modules
-             }
-         self.target_weights = set(
-             [get_stacked_name(module) for module in self.origin_target_modules]
+
+         # Target lora weight names for lora_a and lora_b modules repectively.
+         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
+         self.lora_weight_names: Set[Tuple[str]] = set(
+             [get_stacked_name(module) for module in self.hf_target_names]
          )

          # load all weights to cpu
-         self.loras = []
-         self.lora_id = {}
+         self.loras: Dict[str, LoRAAdapter] = {}
          for name in self.lora_paths.keys():
-             self.lora_id[name] = len(self.loras)
-             self.loras.append(
-                 LoRAAdapter(
-                     name,
-                     self.configs[name],
-                     self.base_hf_config,
-                     self.load_config,
-                     self.lora_backend,
-                 )
+             lora_adapter = LoRAAdapter(
+                 name,
+                 self.configs[name],
+                 self.base_hf_config,
+                 self.load_config,
+                 self.lora_backend,
              )
-             self.loras[-1].initialize_weights()
+             lora_adapter.initialize_weights()
+             self.loras[name] = lora_adapter

          # misc lora configs
-         self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-         self.scaling = self.loras[0].scaling
-         # FIXME remove the restrictions
+         # FIXME remove the restrictions after implementing unified paging
+         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+         self.scaling: float = list(self.loras.values())[0].scaling
          assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-         assert all(x.scaling == self.scaling for x in self.loras)
+         assert all(x.scaling == self.scaling for x in self.loras.values())

-         # monkey patch to use the LoRA version
-         self.lora_modules = []
-         for module_name, module in self.get_target_modules():
-             self.lora_modules.append(
-                 (module_name, self.set_lora_module(module_name, module))
-             )
+         # Convert original model layers to layers with LoRA
+         self.convert_to_lora_layers()

      def init_lora_memory_pool(self):
-         # preallocate lora memory pool
-         self.A_buffer = {}
-         self.B_buffer = {}
-         num_layer = self.base_hf_config.num_hidden_layers
-         for module_A, module_B in self.target_weights:
-             # init A tensor, column_major=True
-             if hasattr(self.base_model, "get_hidden_dim"):
-                 hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
-             else:
-                 logger.warning(
-                     "WARNING: get_hidden_dim() is not defined, "
-                     "which is used to get the hidden dim for different lora modules"
-                     "Use the default one, but please check if it is correct for your model."
-                 )
-                 hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
-             c = self.loras[-1].get_stacked_multiply(module_A)
-             if module_A not in self.A_buffer:
-                 self.A_buffer[module_A] = [
-                     torch.empty(
-                         (
-                             self.max_loras_per_batch,
-                             self.max_lora_dim * c,
-                             hidden_dim_A,
-                         ),
-                         dtype=self.dtype,
-                         device="cuda",
-                     )
-                     for i in range(num_layer)
-                 ]
-             # init B tensor, column_major=True
-             if hasattr(self.base_model, "get_hidden_dim"):
-                 _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
-             else:
-                 logger.warning(
-                     "WARNING: get_hidden_dim() is not defined, "
-                     "which is used to get the hidden dim for different lora modules"
-                     "Use the default one, but please check if it is correct for your model."
-                 )
-                 _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
-             c = self.loras[-1].get_stacked_multiply(module_B)
-             if module_B not in self.B_buffer:
-                 self.B_buffer[module_B] = [
-                     torch.empty(
-                         (
-                             c,
-                             self.max_loras_per_batch,
-                             hidden_dim_B,
-                             self.max_lora_dim,
-                         ),
-                         dtype=self.dtype,
-                         device="cuda",
-                     )
-                     for i in range(num_layer)
-                 ]
-
-     def init_lora_batch(self):
-         self.active_uids = set()  # set of active loras
-         self.buffer_id = {}  # lora uid -> idx in memory pool
-
-     def get_weight_name(self, name, idx):
-         for target_weight_name in self.target_weights:
-             if target_weight_name[idx] in name:
-                 return target_weight_name[idx]
-
-     def load_lora(self, uid, buffer_id):
-         num_layer = self.base_hf_config.num_hidden_layers
-         if uid is None:
-             for i in range(num_layer):
-                 for k in self.A_buffer.keys():
-                     self.A_buffer[k][i][buffer_id] *= 0
-             return
+         # Initialize memory pool
+         self.memory_pool = LoRAMemoryPool(
+             self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+         )

-         for i in range(num_layer):
-             layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
-             for name, weights in layer_weights.items():
-                 if "lora_A" in name:
-                     lora_weight_name = self.get_weight_name(name, 0)
-                     if lora_weight_name:
-                         self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
-                 else:
-                     lora_weight_name = self.get_weight_name(name, 1)
-                     if lora_weight_name:
-                         c = self.loras[-1].get_stacked_multiply(lora_weight_name)
-                         if c > 1:
-                             for j in range(c):
-                                 self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
-                                     weights[j]
-                                 )
-                         else:
-                             self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
-                                 weights
-                             )
+         # Initialize target lora modules in memory pool
+         self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)

      def prepare_lora_batch(self, forward_batch: ForwardBatch):
          # load active loras into lora memory pool
          cur_uids = set(forward_batch.lora_paths)
          assert len(cur_uids) <= self.max_loras_per_batch
-         i = 0
-         j = len(self.active_uids)
-         evictable_uids = list(self.active_uids)
-         for uid in cur_uids:
-             if uid not in self.active_uids:
-                 if j < self.max_loras_per_batch:
-                     index = j
-                     j += 1
-                 else:
-                     while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                         i += 1
-                     assert i < len(evictable_uids)
-                     self.active_uids.remove(evictable_uids[i])
-                     self.buffer_id.pop(evictable_uids[i])
-                     index = i
-                     i += 1
-                 self.load_lora(uid, index)
-                 self.active_uids.add(uid)
-                 self.buffer_id[uid] = index
+         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

+         # FIXME: Handle lora uid with None more safely
          if cur_uids == set([None]):
              return

@@ -332,9 +140,9 @@ class LoRAManager:
          max_len = int(torch.max(seg_lens))
          weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
          for i, lora_path in enumerate(forward_batch.lora_paths):
-             weight_indices[i] = self.buffer_id[lora_path]
+             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)

-         batch_info = LoraBatchInfo(
+         batch_info = LoRABatchInfo(
              bs=bs,
              seg_lens=seg_lens,
              seg_indptr=seg_indptr,
@@ -346,16 +154,40 @@ class LoRAManager:
          # call set_lora_info for each lora modules
          for module_name, module in self.lora_modules:
              layer_id = get_layer_id(module_name)
-
              if "qkv_proj" not in module_name:
-                 weight_name = self.get_weight_name(module_name, 0)
+                 weight_name = get_weight_name(
+                     module_name, self.lora_weight_names, LoRAType.LORA_A
+                 )
                  module.set_lora_info(
-                     self.A_buffer[weight_name][layer_id],
-                     self.B_buffer[weight_name][layer_id],
+                     self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                     self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                  )
              else:
                  module.set_lora_info(
-                     self.A_buffer["qkv_proj"][layer_id],
-                     self.B_buffer["q_proj"][layer_id],
-                     self.B_buffer["kv_proj"][layer_id],
+                     self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
+                     self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
+                     self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
+                 )
+
+     def set_lora_module(self, module_name, module):
+         lora_module = get_lora_layer(
+             module, self.max_lora_dim, self.scaling, self.lora_backend
+         )
+         replace_submodule(self.base_model, module_name, lora_module)
+         return lora_module
+
+     def convert_to_lora_layers(self):
+         # Target module names of customized layers defined in python/sglang/srt/layers
+         # e.g., {"qkv_proj", "o_proj"}
+         customized_target_names = get_customized_names_from_hf_names(
+             self.hf_target_names, self.base_model
+         )
+
+         # Monkey patch to use the LoRA version layers
+         self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+         for module_name, module in self.base_model.named_modules():
+             # The module should be converted if it is included in target_names
+             if module_name.split(".")[-1] in customized_target_names:
+                 self.lora_modules.append(
+                     (module_name, self.set_lora_module(module_name, module))
                  )
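
The refactored LoRAManager injects LoRA by walking the base model's modules and swapping each matching submodule for a LoRA-aware layer (see convert_to_lora_layers / set_lora_module above). Below is a minimal, self-contained sketch of that replacement pattern; ToyLoRALinear and the local replace_submodule helper are illustrative stand-ins, not sglang's get_lora_layer or sglang.srt.utils.replace_submodule.

# Sketch of the submodule-swap pattern used by convert_to_lora_layers (illustrative only;
# ToyLoRALinear and replace_submodule here are stand-ins, not sglang APIs).
import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Wrap a Linear layer and add a rank-r LoRA path on top of the base weight."""

    def __init__(self, base: nn.Linear, r: int = 8, scaling: float = 1.0):
        super().__init__()
        self.base = base
        self.lora_a = nn.Parameter(torch.zeros(r, base.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)


def replace_submodule(model: nn.Module, name: str, new_module: nn.Module) -> nn.Module:
    # Swap a child module in place, addressed by its dotted name from named_modules().
    parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
    setattr(parent, name.rsplit(".", 1)[-1], new_module)
    return new_module


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
target_names = {"0", "2"}  # plays the role of customized_target_names in the diff

lora_modules = []
for module_name, module in list(model.named_modules()):
    if module_name.split(".")[-1] in target_names and isinstance(module, nn.Linear):
        lora_modules.append(
            (module_name, replace_submodule(model, module_name, ToyLoRALinear(module)))
        )

print([name for name, _ in lora_modules])  # ['0', '2']

Because the wrapper keeps the original module as a child, the base weights stay loaded and only the low-rank path changes per adapter.
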
--- /dev/null
+++ b/sglang/srt/lora/mem_pool.py
@@ -0,0 +1,174 @@
+ from typing import Dict, List, Optional, Set, Tuple
+
+ import torch
+
+ from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.lora.lora import LoRAAdapter
+ from sglang.srt.lora.utils import (
+     LoRAType,
+     get_hidden_dim,
+     get_stacked_multiply,
+     get_weight_name,
+ )
+
+
+ class LoRAMemoryPool:
+     """Class for memory pool management of lora modules"""
+
+     def __init__(
+         self,
+         base_hf_config: AutoConfig,
+         max_loras_per_batch: int,
+         max_lora_dim: int,
+         dtype: torch.dtype,
+     ):
+
+         self.base_hf_config: AutoConfig = base_hf_config
+         self.num_layer: int = base_hf_config.num_hidden_layers
+         self.max_loras_per_batch: int = max_loras_per_batch
+         self.max_lora_dim: int = max_lora_dim
+         self.dtype: torch.dtype = dtype
+
+         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
+         # A_buffer contains num_layer number of row-major tensors with shape
+         # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
+         # B_buffer contains num_layer number of column-major tensors with shape
+         # (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
+         self.A_buffer: Dict[str, List[torch.Tensor]] = {}
+         self.B_buffer: Dict[str, List[torch.Tensor]] = {}
+
+         # Lora uid -> buffer idx in memory pool
+         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
+
+         # Buffer idx -> lora uid in memory pool
+         # All uids are initalized as empty strings for empty buffer slots
+         # Here we don't initalize to None since None is a valid uid
+         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+
+     def init_buffers(
+         self,
+         lora_weight_names: Set[Tuple[str]],
+         base_model: torch.nn.Module,
+     ):
+
+         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
+         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
+         self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
+
+         for module_A, module_B in lora_weight_names:
+             # Init A tensor, column_major=False
+             input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
+             c = get_stacked_multiply(module_A)
+             if module_A not in self.A_buffer:
+                 self.A_buffer[module_A] = [
+                     torch.empty(
+                         (
+                             self.max_loras_per_batch,
+                             self.max_lora_dim * c,
+                             input_dim,
+                         ),
+                         dtype=self.dtype,
+                         device="cuda",
+                     )
+                     for i in range(self.num_layer)
+                 ]
+
+             # Init B tensor, column_major=True
+             _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
+             c = get_stacked_multiply(module_B)
+             if module_B not in self.B_buffer:
+                 self.B_buffer[module_B] = [
+                     torch.empty(
+                         (
+                             c,  # stacked lora_b modules might need separation
+                             self.max_loras_per_batch,
+                             output_dim,
+                             self.max_lora_dim,
+                         ),
+                         dtype=self.dtype,
+                         device="cuda",
+                     )
+                     for i in range(self.num_layer)
+                 ]
+
+     def prepare_lora_batch(
+         self,
+         cur_uids: Set[Optional[str]],
+         lora_adapters: Dict[str, LoRAAdapter],
+     ):
+
+         def get_available_buffer_slot():
+             for buffer_id in range(self.max_loras_per_batch):
+                 # Prioritize empty slots
+                 if self.buffer_id_to_uid[buffer_id] == "":
+                     return buffer_id, ""
+
+             for buffer_id in range(self.max_loras_per_batch):
+                 # Evict unneeded lora
+                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
+                     return buffer_id, self.buffer_id_to_uid[buffer_id]
+
+             raise ValueError(
+                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+             )
+
+         for uid in cur_uids:
+             if uid not in self.uid_to_buffer_id:
+                 buffer_id, evicted_lora_uid = get_available_buffer_slot()
+                 if evicted_lora_uid != "":
+                     self.uid_to_buffer_id.pop(evicted_lora_uid)
+                 self.load_lora_weight_to_buffer(
+                     uid, buffer_id, lora_adapters.get(uid, None)
+                 )
+                 self.uid_to_buffer_id[uid] = buffer_id
+                 self.buffer_id_to_uid[buffer_id] = uid
+
+     def load_lora_weight_to_buffer(
+         self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+     ):
+
+         if uid is None:
+             for i in range(self.num_layer):
+                 for k in self.A_buffer.keys():
+                     self.A_buffer[k][i][buffer_id] *= 0
+             return
+
+         assert lora_adapter is not None
+         for layer_id in range(self.num_layer):
+             layer_weights = lora_adapter.layers[layer_id].weights
+             for name, weights in layer_weights.items():
+                 if "lora_A" in name:
+                     lora_weight_name = get_weight_name(
+                         name, self.lora_weight_names, LoRAType.LORA_A
+                     )
+                     if lora_weight_name:
+                         self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
+                             weights
+                         )
+                 else:
+                     lora_weight_name = get_weight_name(
+                         name, self.lora_weight_names, LoRAType.LORA_B
+                     )
+                     if lora_weight_name:
+                         c = get_stacked_multiply(lora_weight_name)
+                         if c > 1:
+                             for stacked_id in range(c):
+                                 self.B_buffer[lora_weight_name][layer_id][stacked_id][
+                                     buffer_id
+                                 ].copy_(weights[stacked_id])
+                         else:
+                             self.B_buffer[lora_weight_name][layer_id][0][
+                                 buffer_id
+                             ].copy_(weights)
+
+     def get_tensor(
+         self, weight_name: str, layer_id: int, lora_type: LoRAType
+     ) -> torch.Tensor:
+
+         if lora_type == LoRAType.LORA_A:
+             return self.A_buffer[weight_name][layer_id]
+
+         return self.B_buffer[weight_name][layer_id]
+
+     def get_buffer_id(self, lora_uid: str):
+         return self.uid_to_buffer_id[lora_uid]
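
The new mem_pool.py centralizes what LoRAManager previously did inline: a fixed set of max_loras_per_batch buffer slots, where prepare_lora_batch first fills empty slots (marked by the empty string) and otherwise evicts a slot whose adapter is not needed by the current batch. Here is a standalone sketch of that slot policy, with plain dict/list bookkeeping in place of the CUDA buffers and the weight-loading step simulated; it is an illustration, not the sglang class itself.

# Standalone illustration of the slot-allocation policy in LoRAMemoryPool.prepare_lora_batch.
from typing import Dict, List, Optional, Set

MAX_SLOTS = 2
uid_to_buffer_id: Dict[Optional[str], int] = {}
buffer_id_to_uid: List[Optional[str]] = [""] * MAX_SLOTS  # "" marks an empty slot


def get_available_buffer_slot(cur_uids: Set[Optional[str]]):
    for buffer_id in range(MAX_SLOTS):
        if buffer_id_to_uid[buffer_id] == "":  # prioritize empty slots
            return buffer_id, ""
    for buffer_id in range(MAX_SLOTS):
        if buffer_id_to_uid[buffer_id] not in cur_uids:  # evict an unneeded adapter
            return buffer_id, buffer_id_to_uid[buffer_id]
    raise ValueError("number of active loras must stay below max_loras_per_batch")


def prepare_lora_batch(cur_uids: Set[Optional[str]]):
    for uid in cur_uids:
        if uid not in uid_to_buffer_id:
            buffer_id, evicted_uid = get_available_buffer_slot(cur_uids)
            if evicted_uid != "":
                uid_to_buffer_id.pop(evicted_uid)
            # The real class calls load_lora_weight_to_buffer(uid, buffer_id, adapter) here.
            uid_to_buffer_id[uid] = buffer_id
            buffer_id_to_uid[buffer_id] = uid


prepare_lora_batch({"adapter_a", "adapter_b"})
prepare_lora_batch({"adapter_c"})  # evicts one of the first two adapters
print(uid_to_buffer_id, buffer_id_to_uid)
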
--- a/sglang/srt/lora/triton_ops/__init__.py
+++ b/sglang/srt/lora/triton_ops/__init__.py
@@ -1,5 +1,11 @@
+ from .gate_up_lora_b import gate_up_lora_b_fwd
  from .qkv_lora_b import qkv_lora_b_fwd
  from .sgemm_lora_a import sgemm_lora_a_fwd
  from .sgemm_lora_b import sgemm_lora_b_fwd

- __all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
+ __all__ = [
+     "gate_up_lora_b_fwd",
+     "qkv_lora_b_fwd",
+     "sgemm_lora_a_fwd",
+     "sgemm_lora_b_fwd",
+ ]
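
With this change, the new gate/up LoRA-B kernel is re-exported from the package namespace alongside the existing kernels, so downstream code can import all four from one place (import only; the kernel's call signature is not part of this diff):

from sglang.srt.lora.triton_ops import (
    gate_up_lora_b_fwd,
    qkv_lora_b_fwd,
    sgemm_lora_a_fwd,
    sgemm_lora_b_fwd,
)
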