sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py CHANGED
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import torch
 from torch import nn
 
@@ -21,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -38,16 +36,28 @@ class BaseLayerWithLoRA(nn.Module):
     def set_lora_info(self, *args):
         pass
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        pass
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        pass
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    """
+    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
+
+    Note: The current version does not yet implement the LoRA functionality.
+    This class behaves exactly the same as the base VocabParallelEmbedding.
+    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
+    """
+
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -55,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -71,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -80,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -101,16 +109,24 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        B = B[start_idx:end_idx, :]
+        return B
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -120,6 +136,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
+            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
             self.B_buffer_gate_up = torch.cat(
                 (B_buffer[0], B_buffer[1]), dim=-2
@@ -128,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -138,20 +155,28 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        # Since the outputs for both gate and up are identical, we use a random one.
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        return B[:, start_idx:end_idx, :]
+
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -193,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -206,20 +231,39 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(
+        self, B: List[torch.Tensor], tp_rank: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B_q, B_kv = B
+        base_layer = self.base_layer
+        q_proj_shard_size = base_layer.q_proj_shard_size
+        kv_proj_shard_size = base_layer.kv_proj_shard_size
+        num_kv_head_replicas = base_layer.num_kv_head_replicas
+
+        q_start_idx = q_proj_shard_size * tp_rank
+        q_end_idx = q_start_idx + q_proj_shard_size
+
+        kv_shard_id = tp_rank // num_kv_head_replicas
+        kv_start_idx = kv_proj_shard_size * kv_shard_id
+        kv_end_idx = kv_start_idx + kv_proj_shard_size
+
+        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -227,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -236,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -274,9 +318,19 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        A = A[:, start_idx:end_idx].contiguous()
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        return B
+
 
 def get_lora_layer(
-    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -288,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
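The slice_lora_b_weights methods added above shard each adapter's LoRA B matrix per tensor-parallel rank. Below is a minimal standalone sketch of the same column-parallel slicing rule; the helper name and shapes are made up for illustration and are not part of the sglang API:

# Illustrative only: slice a LoRA B matrix so each tensor-parallel rank keeps its rows.
import torch

def slice_column_parallel_lora_b(B: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Keep only the output rows owned by this tensor-parallel rank (assumes even split)."""
    output_dim = B.shape[0]
    shard_size = output_dim // tp_size
    start_idx = tp_rank * shard_size
    end_idx = (tp_rank + 1) * shard_size
    return B[start_idx:end_idx, :]

if __name__ == "__main__":
    B = torch.randn(4096, 16)  # (output_dim, r)
    shard = slice_column_parallel_lora_b(B, tp_rank=1, tp_size=4)
    print(shard.shape)  # torch.Size([1024, 16])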
sglang/srt/lora/lora.py CHANGED
@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
         super().__init__()
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.weight_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
 
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = None
+        # lora weights in cpu. The weights are loaded from checkpoint.
+        self.weights: Dict[str, torch.Tensor] = {}
 
 
 class LoRAAdapter(nn.Module):
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
         )
 
         self.weights: Dict[str, torch.Tensor] = {}
-        self.weights_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
-        for layer in self.layers:
-            layer.load_to_gpu()
-
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = None
-        for layer in self.layers:
-            layer.offload_from_gpu()
 
     # initialize the LoRA weights to cpu
     def initialize_weights(self):
sglang/srt/lora/lora_manager.py CHANGED
@@ -23,7 +23,7 @@ import torch
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
-from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
@@ -51,6 +51,8 @@ class LoRAManager:
         load_config: LoadConfig,
         dtype: torch.dtype,
         lora_backend: str = "triton",
+        tp_size: int = 1,
+        tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
         self.lora_paths: Dict[str, str] = lora_paths
@@ -58,6 +60,9 @@ class LoRAManager:
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
         self.dtype: torch.dtype = dtype
+        self.device: torch.device = next(self.base_model.parameters()).device
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
 
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -98,11 +103,14 @@ class LoRAManager:
             self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-        self.scaling: float = list(self.loras.values())[0].scaling
-        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras.values())
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -110,7 +118,13 @@ class LoRAManager:
     def init_lora_memory_pool(self):
         # Initialize memory pool
         self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.max_lora_dim,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+            self.lora_modules,
         )
 
         # Initialize target lora modules in memory pool
@@ -131,14 +145,24 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device="cuda")
+            else torch.ones(bs, device=self.device)
         )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
+        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -146,31 +170,41 @@ class LoRAManager:
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
-        for module_name, module in self.lora_modules:
-            layer_id = get_layer_id(module_name)
-            if "qkv_proj" not in module_name:
-                weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
-                )
-                module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
-                )
-            else:
-                module.set_lora_info(
-                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
-                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
-                )
+        for layer_id, modules in self.lora_modules.items():
+            for module_name, module in modules:
+                if "qkv_proj" in module_name:
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            "qkv_proj", layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            "q_proj", layer_id, LoRAType.LORA_B
+                        ),
+                        self.memory_pool.get_tensor(
+                            "kv_proj", layer_id, LoRAType.LORA_B
+                        ),
+                    )
+                else:
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_B
+                        ),
                    )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
@@ -182,10 +216,13 @@ class LoRAManager:
         )
 
         # Monkey patch to use the LoRA version layers
-        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
+            i: [] for i in range(self.base_hf_config.num_hidden_layers)
+        }
         for module_name, module in self.base_model.named_modules():
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
-                self.lora_modules.append(
+                layer_id = get_layer_id(module_name)
+                self.lora_modules[layer_id].append(
                     (module_name, self.set_lora_module(module_name, module))
                 )