sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py
CHANGED
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.topk import TopKOutput
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -129,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
 
+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -194,6 +203,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+
+        if isinstance(layer, EPMoE):
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         return self.forward(
             x=x,
             layer=layer,
@@ -219,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
-
-
-
-
-
-
-            # gating_output=router_logits,
-            # topk=top_k,
-            # renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
@@ -354,69 +368,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
     forward_native = forward_cpu
-
-
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
sglang/srt/layers/quantization/w4afp8.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer,
+        elif isinstance(layer, EPMoE):
             return W4AFp8MoEMethod(self)
         return None
 
@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def create_weights(
         self,
-        layer:
-
+        layer: EPMoE,
+        num_experts: int,
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.empty(
-
+                num_experts,
                 intermediate_size * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.empty(
-
+                num_experts,
                 hidden_size,
                 intermediate_size // 2,
                 dtype=torch.int8,
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-
+                num_experts,
                 2 * intermediate_size,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-
+                num_experts,
                 hidden_size,
                 intermediate_size // self.quant_config.group_size,
                 dtype=torch.float32,
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         # Input scales
         w13_input_scale = torch.nn.Parameter(
-            torch.ones((
+            torch.ones((num_experts, 2), dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
         set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(num_experts, dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         device = layer.w13_weight.device
 
         self.a_strides1 = torch.full(
-            (
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides1 = torch.full(
-            (
+            (num_experts, 3),
             2 * intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
-            (
+            (num_experts, 3),
             intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides2 = torch.full(
-            (
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self.s_strides2 = self.c_strides2
 
         self.expert_offsets = torch.empty(
-            (
+            (num_experts + 1), dtype=torch.int32, device=device
        )
         self.problem_sizes1 = torch.empty(
-            (
+            (num_experts, 3), dtype=torch.int32, device=device
         )
         self.problem_sizes2 = torch.empty(
-            (
+            (num_experts, 3), dtype=torch.int32, device=device
         )
 
         return
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             [w2_input_scale_max], dtype=dtype, device=device
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
+
+    def apply(
+        self,
+        layer: EPMoE,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ) -> torch.Tensor:
+
+        # TODO(ch-wan): move it out of this class
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
+        topk_ids, topk_weights, _ = topk_output
+        local_topk_ids = topk_ids
+        if layer.expert_map is not None:
+            "Translate info from expert_map to topk_ids"
+            local_topk_ids = torch.where(
+                layer.expert_map[topk_ids] != layer.num_experts,
+                layer.expert_map[topk_ids],
+                layer.num_experts,
+            )
+
+        return cutlass_w4a8_moe(
+            layer.start_expert_id,
+            layer.end_expert_id,
+            layer.num_experts,
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_weights,
+            topk_ids,
+            local_topk_ids,
+            self.a_strides1,
+            self.b_strides1,
+            self.c_strides1,
+            self.a_strides2,
+            self.b_strides2,
+            self.c_strides2,
+            self.s_strides13,
+            self.s_strides2,
+            self.expert_offsets,
+            self.problem_sizes1,
+            self.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
sglang/srt/lora/lora_registry.py
CHANGED
@@ -14,12 +14,16 @@
 
 
 import asyncio
+from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
 
+from sglang.srt.aio_rwlock import RWLock
+from sglang.srt.utils import ConcurrentCounter
 
-
+
+@dataclass(frozen=True)
 class LoRARef:
     """
     Reference record for a LoRA model.
@@ -48,10 +52,11 @@ class LoRARef:
 
 class LoRARegistry:
     """
-    The central registry to keep track of available LoRA adapters.
+    The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
 
-
-
+    The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
+    available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
+    update / eventual consistency model between the tokenizer manager process and the scheduler processes.
     """
 
     def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
@@ -62,8 +67,19 @@ class LoRARegistry:
                 "Please file an issue if you see this error."
             )
 
+        # A read-write lock to ensure adapters loading / unloading operations are exclusive.
+        # Please note that the counter increment/decrement operations are not synchronized through this
+        # lock, as they are designed to be non-blocking and can be performed concurrently.
+        self._registry_lock = RWLock()
         # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
-        self._registry: Dict[str, LoRARef] =
+        self._registry: Dict[str, LoRARef] = {}
+        # Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
+        self._counters: Dict[str, ConcurrentCounter] = {}
+
+        # Initialize the registry with provided LoRA paths, if present.
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                self._register_adapter(lora_ref)
 
     async def register(self, lora_ref: LoRARef):
         """
@@ -72,11 +88,8 @@ class LoRARegistry:
         Args:
             lora_ref (LoRARef): The LoRARef object to register.
         """
-
-
-                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
-            )
-        self._registry[lora_ref.lora_name] = lora_ref
+        async with self._registry_lock.writer_lock:
+            self._register_adapter(lora_ref)
 
     async def unregister(self, lora_name: str) -> str:
         """
@@ -85,12 +98,14 @@ class LoRARegistry:
         Args:
             lora_name (str): The name of the LoRA model to unregister.
         """
-
-
-
-
-
-
+        async with self._registry_lock.writer_lock:
+            lora_ref = self._registry.get(lora_name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+                )
+            del self._registry[lora_name]
+            del self._counters[lora_ref.lora_id]
 
         return lora_ref.lora_id
 
@@ -98,27 +113,76 @@ class LoRARegistry:
         """
         Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
         by incrementing its counter.
-
-        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
         """
 
-
+        def _lookup(name: str) -> str:
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
                     f"The following requested LoRA adapters are not loaded: {name}\n"
                     f"Loaded adapters: {self._registry.keys()}."
                 )
-            # await self._counters[lora_ref.lora_id].increment()
             return lora_ref.lora_id
 
-
-
-
-
-
-
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_name, str):
+                lora_id = _lookup(lora_name)
+                await self._counters[lora_id].increment(notify_all=False)
+                return lora_id
+            elif isinstance(lora_name, list):
+                lora_ids = [_lookup(name) for name in lora_name]
+
+                # Increment the counters only after all IDs are looked up.
+                await asyncio.gather(
+                    *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                )
+                return lora_ids
+            else:
+                raise TypeError(
+                    "lora_name must be either a string or a list of strings."
+                )
+
+    async def release(self, lora_id: Union[str, List[str]]):
+        """
+        Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
+        """
+
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_id, str):
+                await self._counters[lora_id].decrement()
+            elif isinstance(lora_id, list):
+                await asyncio.gather(
+                    *[self._counters[id].decrement() for id in lora_id]
+                )
+            else:
+                raise TypeError("lora_id must be either a string or a list of strings.")
+
+    async def wait_for_unload(self, lora_id: str):
+        """
+        Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
+        This is useful for ensuring that a LoRA adapter can be safely unloaded.
+
+        This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
+        which itself is guaranteed to be sequential.
+        """
+        assert (
+            lora_id not in self._registry
+        ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
+        counter = self._counters.get(lora_id)
+        if counter:
+            # Wait until no requests are using this LoRA adapter.
+            await counter.wait_for_zero()
+            del self._counters[lora_id]
+
+    def _register_adapter(self, lora_ref: LoRARef):
+        """
+        Internal helper method to register a LoRA adapter.
+        """
+
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
             )
-
-
-
+        self._registry[lora_ref.lora_name] = lora_ref
+        self._counters[lora_ref.lora_id] = ConcurrentCounter()
+        return lora_ref
sglang/srt/managers/cache_controller.py
CHANGED
@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
     def increment(self, num_tokens: int):
         with self._lock:
             if self._done_flag:
-                return
+                return False
             self.completed_tokens += num_tokens
+            return True
 
     def mark_done(self):
         with self._lock:
@@ -528,12 +529,12 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                 )
                 break
-            self.
-
-
-
-
-
+            if operation.increment(self.page_size):
+                self.mem_pool_host.set_from_flat_data_page(
+                    operation.host_indices[operation.completed_tokens],
+                    page_data,
+                )
+            else:
                 # operation terminated by controller, release pre-allocated memory
                 self.mem_pool_host.free(
                     operation.host_indices[operation.completed_tokens :]
@@ -589,6 +590,7 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
+                self.mem_pool_host.free(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -26,6 +26,7 @@ import zmq
 
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
+    BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
@@ -282,6 +283,9 @@ class DataParallelController:
                 ),
             ):
                 self.dispatching(recv_req)
+            elif isinstance(recv_req, BlockReqInput):
+                for worker in self.workers:
+                    worker.send_pyobj(recv_req)
             else:
                 # Send other control messages to first worker of tp group
                 for worker in self.workers[:: self.control_message_step]:
sglang/srt/managers/io_struct.py
CHANGED
@@ -911,6 +911,8 @@ class AbortReq:
     rid: str = ""
     # Whether to abort all requests
     abort_all: bool = False
+    # The finished reason data
+    finished_reason: Optional[Dict[str, Any]] = None
 
 
 @dataclass
@@ -1101,3 +1103,13 @@ class LoRAUpdateResult:
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+
+
+class BlockReqType(Enum):
+    BLOCK = 1
+    UNBLOCK = 2
+
+
+@dataclass
+class BlockReqInput:
+    type: BlockReqType