sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py CHANGED
@@ -24,6 +24,7 @@ import triton.language as tl
 
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,
@@ -43,7 +44,7 @@ if _is_cuda:
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
     sm_version = get_device_sm()
-    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
        _enable_jit_deepgemm = True
 
 
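Note: switching from int(os.getenv(...)) to get_bool_env_var makes spellings such as "true"/"false" valid for SGL_ENABLE_JIT_DEEPGEMM. A minimal sketch of the idea; the helper name parse_bool_env below is illustrative, and the exact semantics of sglang's get_bool_env_var are assumed from the diff, not verified:

import os

def parse_bool_env(name: str, default: str = "false") -> bool:
    # Accept common truthy spellings. int(os.getenv(name, "1")) would raise
    # ValueError on "true"/"false" and treat any non-zero digit as enabled.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true", "yes", "y", "on")

# Example: SGL_ENABLE_JIT_DEEPGEMM=true now enables the JIT path,
# whereas int("true") previously failed at import time.
enable_jit_deepgemm = parse_bool_env("SGL_ENABLE_JIT_DEEPGEMM", default="true")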
sglang/srt/layers/quantization/gptq.py CHANGED
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()
 
 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types
 
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
 
+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["GPTQLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.quantization import get_linear_quant_method
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    if VLLM_AVAILABLE:
-        from vllm.scalar_type import scalar_types
-
-        # (num_bits, is_sym) -> quant_type
-        TYPE_MAP = {
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
 
     def __init__(
         self,
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
             )
 
+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
     def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method
 
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        if not VLLM_AVAILABLE:
-            return False
-
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")
 
-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False
 
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["MarlinLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-        # Delay import to avoid circular dependency
+    ) -> Optional[MarlinLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
         if isinstance(layer, LinearBase) or (
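Note: the gptq.py change moves every vllm import to a single module-level try/except and installs placeholders when vllm is missing, so annotations such as Optional[GPTQLinearMethod] and the class-level TYPE_MAP still resolve. A minimal sketch of that pattern under the same assumption (vllm as the optional dependency); it is illustrative, not the exact sglang module:

from typing import Any

try:
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.scalar_type import scalar_types

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Placeholder keeps annotations like Optional[GPTQLinearMethod] importable.
    GPTQLinearMethod = Any

    # Stub mirrors the two scalar types referenced by TYPE_MAP below.
    class scalar_types:
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"

# Class/module attributes can now be defined unconditionally; they only
# hold string stand-ins when vllm is absent, and callers check VLLM_AVAILABLE.
TYPE_MAP = {
    (4, True): scalar_types.uint4b8,
    (8, True): scalar_types.uint8b128,
}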
sglang/srt/layers/quantization/w8a8_fp8.py CHANGED
@@ -37,7 +37,7 @@ class W8A8Fp8Config(QuantizationConfig):
     Note:
     - For models without offline quantization, weights will be quantized during model loading
     - If CUTLASS is supported: Per-channel weight quantization is used
-    - If CUTLASS is not supported: Falls back to per-token weight quantization
+    - If CUTLASS is not supported: Falls back to per-tensor weight quantization
     """
 
     def __init__(self, is_checkpoint_fp8_serialized: bool = False):
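Note: the docstring fix matters because per-channel and per-tensor weight quantization use differently shaped scales. A rough illustration of the two scale layouts for FP8-style weight quantization; this is reference math, not the sglang kernel code:

import torch

fp8_max = 448.0  # dynamic range of float8_e4m3fn
weight = torch.randn(4096, 11008)  # (out_features, in_features)

# Per-channel: one scale per output channel -> shape (out_features, 1)
per_channel_scale = weight.abs().amax(dim=1, keepdim=True) / fp8_max

# Per-tensor: a single scalar scale for the whole weight matrix
per_tensor_scale = weight.abs().amax() / fp8_max

q_per_channel = (weight / per_channel_scale).clamp(-fp8_max, fp8_max)
q_per_tensor = (weight / per_tensor_scale).clamp(-fp8_max, fp8_max)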
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -651,6 +651,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if _is_cuda_available:
+            return self.forward_cuda(positions, query, key, offsets)
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
+    def forward_native(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
         query_rot = query[..., : self.rotary_dim]
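Note: the change splits forward into a thin dispatcher plus the renamed forward_native fallback. A generic sketch of that dispatch pattern; everything here other than the forward_cuda/forward_native method names is illustrative:

import torch

class DeviceDispatchingOp(torch.nn.Module):
    """Pick an implementation at call time based on device availability."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.cuda.is_available() and x.is_cuda:
            return self.forward_cuda(x)
        return self.forward_native(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for a fused CUDA kernel; here it just reuses the native math.
        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path that also runs on CPU.
        return x * 2.0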
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+def get_fuse_output_add_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-        fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-                                 and the operation of scaling and adding will be fused into kernel
+        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+                         and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+        self.fuse_output_add = get_fuse_output_add_from_name(name)
         self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
 
-        return self.segment_gemm.run(
-            x=x,
-            weights=weights,
-            batch_size=self.batch_info.bs,
-            weight_column_major=True,
-            seg_indptr=self.batch_info.seg_indptr,
-            weight_indices=self.batch_info.weight_indices,
+        return (
+            self.segment_gemm.run(
+                x=x,
+                weights=weights,
+                batch_size=self.batch_info.bs,
+                weight_column_major=True,
+                seg_indptr=self.batch_info.seg_indptr,
+                weight_indices=self.batch_info.weight_indices,
+            )
+            * self.batch_info.scalings[0]
         )
 
     def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=kv_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
 
     def run_gate_up_lora(
         self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=gate_up_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
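Note: with scaling moved into the batch info, the flashinfer backend multiplies by scalings[0] after the segment GEMM, i.e. one scaling for the whole batch, which is why the manager keeps the uniform-adapter assertions for this backend. For reference, a LoRA adapter's scaling is conventionally derived from its config as alpha / r (illustrative helper, not sglang code):

def lora_scaling(lora_alpha: float, r: int) -> float:
    # Conventional LoRA scaling, as in peft-style adapter configs.
    return lora_alpha / r

# Example: alpha=32, rank=16 -> every B @ A @ x contribution is scaled by 2.0
print(lora_scaling(32, 16))  # 2.0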
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
         x: torch.Tensor,
         weights: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
-        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
 
     def run_qkv_lora(
         self,
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
         output_offset: torch.Tensor,
         max_qkv_out_dim: int,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
         assert isinstance(qkv_lora_b, torch.Tensor)
 
-        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
         lora_output = qkv_lora_b_fwd(
             lora_a_output,
             qkv_lora_b,
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
             output_offset,
             max_qkv_out_dim,
             base_output,
-            scaling,
         )
         return lora_output
 
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
         output_dim = gate_up_lora_b.shape[-2] // 2
 
         # lora_a_output: (s, 2 * r)
-        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(
+            x, gate_up_lora_a, self.batch_info, stack_num=2
+        )
         lora_output = gate_up_lora_b_fwd(
             lora_a_output,
             gate_up_lora_b,
             self.batch_info,
             output_dim,
             base_output,
-            scaling,
         )
         return lora_output
sglang/srt/lora/layers.py CHANGED
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -96,8 +88,8 @@
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -165,8 +155,8 @@
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -243,8 +231,8 @@
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -294,8 +280,8 @@
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
 
 def get_lora_layer(
-    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -356,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
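Note: since the layers no longer carry lora_rank or scaling, the unfused path reduces to base_output + lora_output, with the backend expected to have applied each adapter's scaling already. A plain-PyTorch reference of what one apply_lora call computes for a single adapter; shapes and the function name are illustrative:

import torch

def apply_lora_reference(
    base_output: torch.Tensor,  # (s, out_dim), result of the frozen base layer
    x: torch.Tensor,            # (s, in_dim)
    lora_a: torch.Tensor,       # (r, in_dim)
    lora_b: torch.Tensor,       # (out_dim, r)
    scaling: float,             # alpha / r for this adapter
) -> torch.Tensor:
    lora_a_output = x @ lora_a.t()            # (s, r)
    lora_output = lora_a_output @ lora_b.t()  # (s, out_dim)
    # In sglang the backend now applies `scaling` before returning lora_output;
    # the reference keeps it explicit so the full math is visible.
    return base_output + lora_output * scaling

s, in_dim, out_dim, r = 4, 64, 128, 16
out = apply_lora_reference(
    torch.randn(s, out_dim), torch.randn(s, in_dim),
    torch.randn(r, in_dim), torch.randn(out_dim, r), scaling=2.0,
)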
sglang/srt/lora/lora_manager.py CHANGED
@@ -103,11 +103,14 @@ class LoRAManager:
             self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-        self.scaling: float = list(self.loras.values())[0].scaling
-        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras.values())
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -148,8 +151,18 @@ class LoRAManager:
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -157,6 +170,8 @@
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
@@ -189,9 +204,7 @@
         )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
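Note: the new lora_ranks and scalings tensors are indexed by memory-pool buffer slot rather than by request position, so a kernel can resolve each request's rank and scaling through its weight index. A small illustration of that indexing with toy values (not the real pool):

import torch

max_loras_per_batch = 4
# Slot -> adapter metadata, filled while walking the batch's lora_paths.
lora_ranks = torch.zeros(max_loras_per_batch, dtype=torch.int64)
scalings = torch.zeros(max_loras_per_batch, dtype=torch.float)

# Suppose request 0 uses the adapter in slot 2 (rank 8, alpha 16)
# and request 1 uses the adapter in slot 0 (rank 16, alpha 32).
weight_indices = torch.tensor([2, 0], dtype=torch.int64)
lora_ranks[2], scalings[2] = 8, 16 / 8
lora_ranks[0], scalings[0] = 16, 32 / 16

# Inside a kernel, each request looks up its metadata through its slot:
for i in range(weight_indices.numel()):
    slot = weight_indices[i]
    print(int(lora_ranks[slot]), float(scalings[slot]))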
sglang/srt/lora/mem_pool.py CHANGED
@@ -163,10 +163,11 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, torch.Tensor] = {}
@@ -208,17 +209,22 @@
             )
 
             for name, weights in temp_A_buffer.items():
-                self.A_buffer[name][layer_id][buffer_id].copy_(weights)
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
-                            weights[stacked_id]
-                        )
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
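Note: the pool now copies an adapter's weights into only the leading lora_rank rows/columns of each fixed-size slot, which is what allows adapters of different ranks to share one pool; the rank-aware kernels clamp K to the adapter's rank so the unused tail is never read. A toy version of that sliced copy (buffer sizes are made up):

import torch

max_rank, in_dim = 32, 64
# One pre-allocated A-buffer slot sized for the largest supported rank.
a_slot = torch.zeros(max_rank, in_dim)

# An adapter with a smaller rank only fills the top rows of the slot.
lora_rank = 8
adapter_a = torch.randn(lora_rank, in_dim)
a_slot[:lora_rank, :].copy_(adapter_a)

# Rows lora_rank..max_rank stay zero; kernels that clamp K to lora_rank
# never include these padding rows in the GEMM.
assert torch.equal(a_slot[lora_rank:], torch.zeros(max_rank - lora_rank, in_dim))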
sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output
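Note: both Triton kernels now load the per-adapter rank and scaling by weight index and clamp the reduction length K to that rank. In plain PyTorch, the per-segment computation looks roughly like this; it is a reference for the semantics, not the tiled kernel:

import torch

def lora_b_segment_reference(
    lora_a_output: torch.Tensor,  # (seg_len, max_rank), zero-padded past `rank`
    lora_b: torch.Tensor,         # (output_dim, max_rank) slot for this adapter
    rank: int,                    # adapter-specific rank from lora_ranks[w_index]
    scaling: float,               # adapter-specific scaling from scalings[w_index]
    base_output: torch.Tensor,    # (seg_len, output_dim)
) -> torch.Tensor:
    # Limiting the reduction dimension to `rank` mirrors K = tl.minimum(K, rank).
    contribution = lora_a_output[:, :rank] @ lora_b[:, :rank].t()
    return base_output + scaling * contribution

seg_len, max_rank, output_dim, rank = 5, 32, 128, 8
out = lora_b_segment_reference(
    torch.randn(seg_len, max_rank), torch.randn(output_dim, max_rank),
    rank, 2.0, torch.randn(seg_len, output_dim),
)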
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output