sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py CHANGED
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
 
 def get_lora_layer(
-    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -356,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
sglang/srt/lora/lora_manager.py CHANGED
@@ -103,11 +103,14 @@ class LoRAManager:
             self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-        self.scaling: float = list(self.loras.values())[0].scaling
-        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras.values())
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -148,8 +151,18 @@ class LoRAManager:
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -157,6 +170,8 @@ class LoRAManager:
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
@@ -189,9 +204,7 @@ class LoRAManager:
         )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
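prepare_lora_batch now records a rank and a scaling per memory-pool slot instead of one global value. A simplified sketch of that bookkeeping (the function and dictionary names are illustrative stand-ins, and the device defaults to CPU so the snippet runs anywhere):

import torch

def build_adapter_metadata(lora_paths, get_buffer_id, adapters, max_loras_per_batch, device="cpu"):
    """Return (weight_indices, lora_ranks, scalings); the latter two are indexed by pool slot."""
    bs = len(lora_paths)
    weight_indices = torch.empty((bs,), dtype=torch.int64, device=device)
    lora_ranks = torch.zeros((max_loras_per_batch,), dtype=torch.int64, device=device)
    scalings = torch.zeros((max_loras_per_batch,), dtype=torch.float, device=device)
    for i, path in enumerate(lora_paths):
        slot = get_buffer_id(path)                  # memory-pool slot holding this adapter
        weight_indices[i] = slot
        lora_ranks[slot] = adapters[path]["r"]      # adapter-specific rank
        scalings[slot] = adapters[path]["scaling"]  # adapter-specific alpha / r
    return weight_indices, lora_ranks, scalings

# Two requests use adapter "a" (rank 8) and one uses adapter "b" (rank 16):
adapters = {"a": {"r": 8, "scaling": 2.0}, "b": {"r": 16, "scaling": 0.5}}
slots = {"a": 0, "b": 1}
idx, ranks, scales = build_adapter_metadata(["a", "b", "a"], slots.get, adapters, max_loras_per_batch=4)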
sglang/srt/lora/mem_pool.py CHANGED
@@ -163,10 +163,11 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, torch.Tensor] = {}
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
             )
 
             for name, weights in temp_A_buffer.items():
-                self.A_buffer[name][layer_id][buffer_id].copy_(weights)
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
-                            weights[stacked_id]
-                        )
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
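The memory-pool change writes each adapter into a slot sized for the largest rank, copying only the first lora_rank rows of A (times the stack factor) and the first lora_rank columns of B. A toy illustration with made-up shapes, assuming a single adapter and an unstacked weight:

import torch

max_rank, input_dim, output_dim = 16, 64, 128
A_slot = torch.zeros(max_rank, input_dim)   # pool slot for one adapter's A matrix
B_slot = torch.zeros(output_dim, max_rank)  # pool slot for one adapter's B matrix

rank = 8                                    # this adapter's actual rank (< max_rank)
A_weights = torch.randn(rank, input_dim)
B_weights = torch.randn(output_dim, rank)

A_slot[:rank, :].copy_(A_weights)           # only the first `rank` rows are meaningful
B_slot[:, :rank].copy_(B_weights)           # only the first `rank` columns are meaningful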
sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
        batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s, r)
-    # weights: (num_lora, output_dim, r)
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than r
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
     return output
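All four Triton kernels above now read the adapter's rank and scaling from lora_ranks/scalings (indexed through weight_indices) and clamp the reduction dimension accordingly. A plain-PyTorch reference of what the expand (lora_b) path computes per sequence segment; it is a mental model and test oracle only, not the Triton code itself:

import torch

def lora_b_reference(x, weights, seg_indptr, weight_indices, lora_ranks, scalings, base_output):
    # x: (s, max_r), weights: (num_lora, output_dim, max_r), base_output: (s, output_dim)
    out = base_output.clone()
    for b in range(len(weight_indices)):
        w = int(weight_indices[b])                # adapter slot for this segment
        r = int(lora_ranks[w])                    # clamp to the adapter's true rank
        s0, s1 = int(seg_indptr[b]), int(seg_indptr[b + 1])
        delta = x[s0:s1, :r] @ weights[w, :, :r].T
        out[s0:s1] += delta * float(scalings[w])  # per-adapter scaling
    return out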
sglang/srt/lora/utils.py CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
sglang/srt/managers/cache_controller.py CHANGED
@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@
         self.load_stream = torch.cuda.Stream()
 
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@
         self.ack_load_queue.queue.clear()
 
         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
                 )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@
 
         self.layer_done_counter.reset()
         for i in range(self.mem_pool_host.layer_num):
-            flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                batch_operation.host_indices, i
-            )
-            self.mem_pool_device.transfer_per_layer(
-                batch_operation.device_indices, flat_data, i
-            )
+            if self.page_size == 1:
+                flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                    batch_operation.host_indices, i
+                )
+                self.mem_pool_device.transfer_per_layer(
+                    batch_operation.device_indices, flat_data, i
+                )
+            else:
+                self.mem_pool_host.load_page_per_layer(
+                    batch_operation.host_indices,
+                    batch_operation.device_indices,
+                    self.mem_pool_device,
+                    i,
+                )
+                self.load_stream.synchronize()
             self.layer_done_counter.increment()
 
         self.mem_pool_host.complete_io(batch_operation.host_indices)
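HiCacheController now takes page_size and picks the writer thread's target once at construction: the buffered, token-granular path when page_size == 1, otherwise the direct, page-granular path. A stripped-down sketch of that dispatch pattern (class and method names are simplified placeholders, not the real controller):

import threading

class ToyCacheController:
    def __init__(self, page_size: int):
        self.page_size = page_size
        self.write_thread = threading.Thread(
            target=(
                self._write_buffered     # token-granular path, used when page_size == 1
                if page_size == 1
                else self._write_direct  # page-granular path copies whole pages per layer
            ),
            daemon=True,
        )

    def _write_buffered(self):
        pass  # placeholder body

    def _write_direct(self):
        pass  # placeholder body

controller = ToyCacheController(page_size=16)  # would run the direct path once started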
sglang/srt/managers/io_struct.py CHANGED
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -650,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[str]] = None
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
 
 
 class ProfileReqType(Enum):
@@ -675,6 +675,8 @@ class ProfileReq:
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 @dataclass
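The io_struct changes restrict ProfileReqInput.activities to a Literal of the accepted profiler targets and add with_stack / record_shapes flags to ProfileReq. A minimal stand-in dataclass showing how the Literal annotation constrains callers; it combines fields from both dataclasses for brevity and is a local illustration, not an import from sglang:

from dataclasses import dataclass
from typing import List, Literal, Optional

@dataclass
class ProfileReqInputSketch:
    num_steps: Optional[int] = None
    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
    with_stack: Optional[bool] = None       # added on ProfileReq in the diff above
    record_shapes: Optional[bool] = None    # added on ProfileReq in the diff above

req = ProfileReqInputSketch(activities=["CPU", "GPU"], with_stack=True)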