sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-import importlib
-from typing import TYPE_CHECKING,
+import importlib.util
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -116,9 +116,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if use_intel_amx_backend(layer):
-
+            x_shapes = x.shape
+            if len(x_shapes) == 3:
+                x = x.view(-1, x.shape[-1])
+            output = torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
+            if len(x_shapes) == 3:
+                output = output.view(x_shapes[0], x_shapes[1], -1)
+            return output
 
         return F.linear(x, layer.weight, bias)
 
@@ -221,31 +227,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        kwargs = {}
-        if activation_alpha is not None:
-            kwargs["activation_alpha"] = activation_alpha
-        if swiglu_limit is not None:
-            kwargs["swiglu_limit"] = swiglu_limit
 
         return self.forward(
             x=x,
             layer=layer,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            **kwargs,
+            moe_runner_config=moe_runner_config,
         )
 
     def forward_cuda(
@@ -253,18 +242,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
             if self.with_bias:
+                assert self.triton_kernel_moe_with_bias_forward is not None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
@@ -272,24 +255,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                     w1_pcg=None,
                     w2_pcg=None,
                 )
             else:
+                assert self.triton_kernel_moe_forward is not None
                 return self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
            if _use_aiter:
-                assert not no_combine, "unsupported"
+                assert not moe_runner_config.no_combine, "unsupported"
                 topk_weights, topk_ids, _ = topk_output
-                if apply_router_weight_on_input:
+                if moe_runner_config.apply_router_weight_on_input:
                     assert (
                         topk_weights.dim() == 2
                     ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -309,7 +292,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_ids,
                     activation=(
                         ActivationType.Silu
-                        if activation == "silu"
+                        if moe_runner_config.activation == "silu"
                         else ActivationType.Gelu
                     ),
                 )
@@ -325,13 +308,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 b1=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
                 topk_output=topk_output,
-
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
+                moe_runner_config=moe_runner_config,
             )
 
     def forward_cpu(
@@ -339,21 +316,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert
-
-
+        assert (
+            moe_runner_config.activation == "silu"
+        ), f"activation = {moe_runner_config.activation} is not supported."
+
+        if (
+            use_intel_amx_backend(layer)
+            and not moe_runner_config.apply_router_weight_on_input
+        ):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
            topk_weights, topk_ids, _ = topk_output
            x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
            )
            return torch.ops.sgl_kernel.fused_experts_cpu(
                x,
@@ -378,11 +355,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_npu(
@@ -390,12 +363,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 
@@ -403,11 +371,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
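The hunks above replace the long tail of per-call keyword arguments (`activation`, `apply_router_weight_on_input`, `inplace`, `no_combine`, `routed_scaling_factor`, `activation_alpha`, `swiglu_limit`) with a single `moe_runner_config: MoeRunnerConfig` parameter. Below is a minimal sketch of what such a config object could look like, inferred only from the attribute accesses visible in this diff; the real definition lives in the new `sglang/srt/layers/moe/moe_runner` package and may differ in fields and defaults.

```python
# Illustrative sketch only; field names are taken from the attribute accesses
# visible in the diff above (moe_runner_config.activation, .no_combine, ...),
# and the defaults shown here are assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfig:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    activation_alpha: Optional[float] = None
    swiglu_limit: Optional[float] = None


# Callers now thread one object through apply()/forward_*() instead of seven kwargs:
config = MoeRunnerConfig(activation="silu", routed_scaling_factor=1.0)
```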
@@ -11,13 +11,39 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
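The new `get_scalar_types()` helper lets the quantization utilities use the scalar-type descriptors from `sgl_kernel` when it is installed and fall back to lightweight mocks otherwise. A brief usage sketch, assuming the helper shown above is in scope:

```python
# With sgl_kernel installed, scalar_types.uint4b8 is the real kernel-side descriptor;
# with the mock fallback it is simply the string "uint4b8", and unknown attributes
# resolve to "mock_<name>" via __getattr__, so import-time references keep working.
ScalarType, scalar_types = get_scalar_types()

weight_type = scalar_types.uint4b8
print(weight_type)
```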
@@ -120,6 +146,10 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def update_tensor_inplace(old: torch.Tensor, new: torch.Tensor) -> None:
+    old.copy_(new)
+
+
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
 # Newly generated tensors need to replace existing tensors that are
 # already registered as parameters by vLLM (and won't be freed)
@@ -146,6 +176,27 @@ def replace_parameter(
     mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
 
 
+def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
+    assert a.shape == b.shape
+    assert a.dtype == b.dtype == torch.float8_e4m3fn
+
+    a_u8 = a.view(torch.uint8)
+    b_u8 = b.view(torch.uint8)
+    diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
+
+    numel = a.numel()
+
+    count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
+    count_tiny_diff = (diff_u8 >= 1).sum().item()
+    count_large_diff = (diff_u8 >= 2).sum().item()
+
+    assert (
+        (count_diff_sign == 0)
+        and (count_tiny_diff / numel < 0.005)
+        and (count_large_diff == 0)
+    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
+
+
 # Match dynamic rules with module name (prefix) and override quantize
 # config if module (prefix) matches a rule
 def override_config(config: QuantizationConfig, prefix: str):
@@ -295,6 +346,30 @@ def pack_cols(
     return q_res
 
 
+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
 def unpack_cols(
     packed_q_w: torch.Tensor,
     num_bits: int,
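`pack_rows` mirrors the existing `pack_cols`: it collapses `pack_factor` consecutive rows of `num_bits`-wide values into one row of 32-bit words along the k dimension. The toy sketch below repeats the same packing loop on a small tensor; the `32 // num_bits` formula for `get_pack_factor` is an assumption here (matching the vLLM-derived marlin utilities) rather than something shown in this diff.

```python
import numpy
import torch

# Toy rerun of the pack_rows loop above for num_bits=4, so pack_factor = 32 // 4 = 8:
# every 8 consecutive rows collapse into one uint32 row.
num_bits, size_k, size_n = 4, 8, 2
pack_factor = 32 // num_bits  # assumed definition of get_pack_factor(num_bits)

q_w = torch.arange(size_k * size_n, dtype=torch.int32).reshape(size_k, size_n) % 16
q_np = q_w.numpy().astype(numpy.uint32)

q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
for i in range(pack_factor):
    q_res |= q_np[i::pack_factor, :] << num_bits * i

print(q_res)  # one packed row; column 0 packs the values 0,2,4,...,14 into nibbles
```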
@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe
+    from sglang.srt.layers.moe import MoeRunnerConfig
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: EPMoE,
         x: torch.Tensor,
-        topk_output:
-
-        apply_router_weight_on_input: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         # TODO(ch-wan): move it out of this class
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
         return output

@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output:
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

@@ -49,6 +49,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 _is_cuda = is_cuda()
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
 
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer,
         x,
         topk_output: TopKOutput,
-
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         topk_weights, topk_ids, _ = topk_output

@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)
         self.attn_type = attn_type
 
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
sglang/srt/layers/sampler.py CHANGED
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
-        if
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
sglang/srt/lora/layers.py CHANGED
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
 
-        if
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
sglang/srt/lora/lora_manager.py CHANGED
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-
-
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -55,7 +55,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -350,19 +350,27 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-
-                    module_name, self.memory_pool.
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(
-
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )
 
     def init_state(
         self,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         """
         Initialize the internal (mutable) state of the LoRAManager.
@@ -380,12 +388,11 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
 
-    def init_lora_adapters(self, lora_paths: Optional[
+    def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}
 
@@ -399,7 +406,7 @@ class LoRAManager:
         self.num_pinned_loras: int = 0
 
         if lora_paths:
-            for lora_ref in lora_paths
+            for lora_ref in lora_paths:
                 result = self.load_lora_adapter(lora_ref)
                 if not result.success:
                     raise RuntimeError(
@@ -426,6 +433,7 @@ class LoRAManager:
                 "enable all support modules types. "
             )
             self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
             default=0,
         )
 
-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )
 
@@ -494,7 +493,7 @@ class LoRAManager:
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
sglang/srt/lora/lora_registry.py CHANGED
@@ -59,9 +59,9 @@ class LoRARegistry:
         update / eventual consistency model between the tokenizer manager process and the scheduler processes.
         """
 
-    def __init__(self, lora_paths: Optional[
+    def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
         assert lora_paths is None or all(
-            isinstance(lora, LoRARef) for lora in lora_paths
+            isinstance(lora, LoRARef) for lora in lora_paths
         ), (
             "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
             "Please file an issue if you see this error."
@@ -78,7 +78,7 @@ class LoRARegistry:
 
         # Initialize the registry with provided LoRA paths, if present.
         if lora_paths:
-            for lora_ref in lora_paths
+            for lora_ref in lora_paths:
                 self._register_adapter(lora_ref)
 
     async def register(self, lora_ref: LoRARef):
async def register(self, lora_ref: LoRARef):
|