sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-import importlib
-from typing import TYPE_CHECKING,
+import importlib.util
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F

@@ -24,7 +24,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

@@ -116,9 +116,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if use_intel_amx_backend(layer):
-
+            x_shapes = x.shape
+            if len(x_shapes) == 3:
+                x = x.view(-1, x.shape[-1])
+            output = torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
+            if len(x_shapes) == 3:
+                output = output.view(x_shapes[0], x_shapes[1], -1)
+            return output
 
         return F.linear(x, layer.weight, bias)
 

@@ -221,31 +227,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        kwargs = {}
-        if activation_alpha is not None:
-            kwargs["activation_alpha"] = activation_alpha
-        if swiglu_limit is not None:
-            kwargs["swiglu_limit"] = swiglu_limit
 
         return self.forward(
             x=x,
             layer=layer,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            **kwargs,
+            moe_runner_config=moe_runner_config,
         )
 
     def forward_cuda(

@@ -253,18 +242,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
             if self.with_bias:
+                assert self.triton_kernel_moe_with_bias_forward is not None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,

@@ -272,24 +255,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                     w1_pcg=None,
                     w2_pcg=None,
                 )
             else:
+                assert self.triton_kernel_moe_forward is not None
                 return self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             if _use_aiter:
-                assert not no_combine, "unsupported"
+                assert not moe_runner_config.no_combine, "unsupported"
                 topk_weights, topk_ids, _ = topk_output
-                if apply_router_weight_on_input:
+                if moe_runner_config.apply_router_weight_on_input:
                     assert (
                         topk_weights.dim() == 2
                     ), "`topk_weights` should be in shape (num_tokens, topk)"

@@ -309,7 +292,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_ids,
                     activation=(
                         ActivationType.Silu
-                        if activation == "silu"
+                        if moe_runner_config.activation == "silu"
                         else ActivationType.Gelu
                     ),
                 )

@@ -325,13 +308,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
                     topk_output=topk_output,
-
-                    activation=activation,
-                    apply_router_weight_on_input=apply_router_weight_on_input,
-                    no_combine=no_combine,
-                    routed_scaling_factor=routed_scaling_factor,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
 
     def forward_cpu(

@@ -339,21 +316,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
-        assert
-
-
+        assert (
+            moe_runner_config.activation == "silu"
+        ), f"activation = {moe_runner_config.activation} is not supported."
+
+        if (
+            use_intel_amx_backend(layer)
+            and not moe_runner_config.apply_router_weight_on_input
+        ):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,

@@ -378,11 +355,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_npu(

@@ -390,12 +363,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 

@@ -403,11 +371,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
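The recurring change in the hunks above (and in the quantization backends further down) is that the loose per-call MoE keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor, activation_alpha, swiglu_limit) are folded into a single moe_runner_config argument. The snippet below is only an illustrative sketch of that pattern: MoeRunnerConfigSketch and apply_moe_sketch are hypothetical stand-ins, not the actual sglang.srt.layers.moe.moe_runner.MoeRunnerConfig API.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F


@dataclass
class MoeRunnerConfigSketch:
    # Illustrative stand-in only; field names mirror the kwargs removed in this diff.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    activation_alpha: Optional[float] = None
    swiglu_limit: Optional[float] = None


def apply_moe_sketch(x: torch.Tensor, config: MoeRunnerConfigSketch) -> torch.Tensor:
    # Backends now read settings from one config object instead of seven loose kwargs.
    assert config.activation in ("silu", "gelu"), f"unsupported {config.activation=}"
    out = F.silu(x) if config.activation == "silu" else F.gelu(x)
    if config.routed_scaling_factor is not None:
        out = out * config.routed_scaling_factor
    return out


print(apply_moe_sketch(torch.ones(2, 4), MoeRunnerConfigSketch(routed_scaling_factor=0.5)))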
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -146,6 +146,10 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def update_tensor_inplace(old: torch.Tensor, new: torch.Tensor) -> None:
+    old.copy_(new)
+
+
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
 # Newly generated tensors need to replace existing tensors that are
 # already registered as parameters by vLLM (and won't be freed)

@@ -172,6 +176,27 @@ def replace_parameter(
     mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
 
 
+def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
+    assert a.shape == b.shape
+    assert a.dtype == b.dtype == torch.float8_e4m3fn
+
+    a_u8 = a.view(torch.uint8)
+    b_u8 = b.view(torch.uint8)
+    diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
+
+    numel = a.numel()
+
+    count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
+    count_tiny_diff = (diff_u8 >= 1).sum().item()
+    count_large_diff = (diff_u8 >= 2).sum().item()
+
+    assert (
+        (count_diff_sign == 0)
+        and (count_tiny_diff / numel < 0.005)
+        and (count_large_diff == 0)
+    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
+
+
 # Match dynamic rules with module name (prefix) and override quantize
 # config if module (prefix) matches a rule
 def override_config(config: QuantizationConfig, prefix: str):
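The new assert_fp8_all_close helper above compares two float8_e4m3fn tensors by reinterpreting their bytes as uint8 and bounding how far apart the resulting code points may drift. A rough usage sketch follows; it assumes a PyTorch build with float8_e4m3fn support, and the tensor names are illustrative.

import torch

# Assumption: torch >= 2.1 with float8_e4m3fn available.
a = torch.randn(4, 8).to(torch.float8_e4m3fn)
b = a.clone()

# Reinterpret the 1-byte fp8 payloads bit-for-bit and measure code-point distance.
a_u8 = a.view(torch.uint8)
b_u8 = b.view(torch.uint8)
diff = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()

# Identical tensors differ by zero code points, so a tolerance of "rarely one
# code point apart, never two" (as in assert_fp8_all_close) passes trivially.
print(int((diff >= 1).sum()), int((diff >= 2).sum()))  # 0 0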
sglang/srt/layers/quantization/w4afp8.py
CHANGED
@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe
+    from sglang.srt.layers.moe import MoeRunnerConfig
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 

@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: EPMoE,
         x: torch.Tensor,
-        topk_output:
-
-        apply_router_weight_on_input: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         # TODO(ch-wan): move it out of this class

@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
         return output
sglang/srt/layers/quantization/w8a8_fp8.py
CHANGED
@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 _is_fp8_fnuz = is_fp8_fnuz()
 

@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output:
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 

@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -49,6 +49,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 _is_cuda = is_cuda()

@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 

@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,

@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
 

@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer,
         x,
         topk_output: TopKOutput,
-
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
         topk_weights, topk_ids, _ = topk_output
sglang/srt/layers/radix_attention.py
CHANGED
@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,

@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)
         self.attn_type = attn_type
 
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
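The new logit_capping_method="tanh" argument above selects tanh-style soft capping of attention logits. As a hedged illustration of what tanh capping generally means (the exact formula and cap value used by sglang's attention backends are not shown in this diff, and cap_logits_tanh is a hypothetical helper):

import torch


def cap_logits_tanh(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Hypothetical helper: smoothly squash logits into (-cap, cap) instead of hard clipping.
    return cap * torch.tanh(logits / cap)


print(cap_logits_tanh(torch.tensor([0.0, 10.0, 100.0])))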
sglang/srt/lora/lora_manager.py
CHANGED
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-
-
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -55,7 +55,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config

@@ -350,19 +350,27 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-
-                    module_name, self.memory_pool.
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(
-
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )
 
     def init_state(
         self,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
-        lora_paths: Optional[
+        lora_paths: Optional[List[LoRARef]] = None,
     ):
         """
         Initialize the internal (mutable) state of the LoRAManager.

@@ -380,12 +388,11 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
 
-    def init_lora_adapters(self, lora_paths: Optional[
+    def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}
 

@@ -399,7 +406,7 @@ class LoRAManager:
         self.num_pinned_loras: int = 0
 
         if lora_paths:
-            for lora_ref in lora_paths
+            for lora_ref in lora_paths:
                 result = self.load_lora_adapter(lora_ref)
                 if not result.success:
                     raise RuntimeError(

@@ -426,6 +433,7 @@ class LoRAManager:
                     "enable all support modules types. "
                 )
                 self.target_modules.update(config.target_modules)
+            self.target_modules = get_normalized_target_modules(self.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank

@@ -435,15 +443,6 @@ class LoRAManager:
             default=0,
         )
 
-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.

@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )
 

@@ -494,7 +493,7 @@ class LoRAManager:
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
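The LoRAManager changes above replace the old per-weight-name bookkeeping with a normalized target_modules set plus get_target_module_name lookups into the memory pool. The sketch below only illustrates that suffix-based matching pattern; the real helpers live in sglang.srt.lora.utils, and normalize_targets/find_target_module here are hypothetical names, not the actual implementations.

from typing import Iterable, Set


def normalize_targets(modules: Iterable[str]) -> Set[str]:
    # Hypothetical sketch: collapse fully qualified names such as
    # "model.layers.0.self_attn.q_proj" to their final component ("q_proj")
    # so adapters can be matched by suffix.
    return {m.split(".")[-1] for m in modules}


def find_target_module(module_name: str, targets: Set[str]) -> str:
    # Pick the normalized target entry that the module name ends with.
    suffix = module_name.split(".")[-1]
    if suffix in targets:
        return suffix
    raise ValueError(f"{module_name} matches no LoRA target module")


targets = normalize_targets(["model.layers.0.self_attn.q_proj", "gate_proj"])
print(find_target_module("model.layers.7.mlp.gate_proj", targets))  # gate_proj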
sglang/srt/lora/lora_registry.py
CHANGED
@@ -59,9 +59,9 @@ class LoRARegistry:
     update / eventual consistency model between the tokenizer manager process and the scheduler processes.
     """
 
-    def __init__(self, lora_paths: Optional[
+    def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
         assert lora_paths is None or all(
-            isinstance(lora, LoRARef) for lora in lora_paths
+            isinstance(lora, LoRARef) for lora in lora_paths
         ), (
             "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
             "Please file an issue if you see this error."

@@ -78,7 +78,7 @@ class LoRARegistry:
 
         # Initialize the registry with provided LoRA paths, if present.
        if lora_paths:
-            for lora_ref in lora_paths
+            for lora_ref in lora_paths:
                 self._register_adapter(lora_ref)
 
     async def register(self, lora_ref: LoRARef):