sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/elementwise.py
CHANGED
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
 
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
        )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
 
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
        ),
    }
 
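The new experts_combine_triton helper fuses the final MoE/shared-MLP merge into one kernel launch. Below is a minimal plain-PyTorch sketch of the contract it implements (the shapes and tolerances are illustrative assumptions, not taken from the diff; a CUDA device with Triton is assumed):

import torch

from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 4, 2, 1024  # illustrative shapes only
moe = torch.randn(bs, combine_k, hidden_dim, dtype=torch.bfloat16, device="cuda")
mlp = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")

# Reference of what the kernel computes: sum the combine_k MoE partial outputs
# per token, add the shared-MLP branch, and rescale by 1/sqrt(2).
reference = (moe.sum(dim=1) + mlp) / 1.4142135623730951

out = experts_combine_triton(moe, mlp)
torch.testing.assert_close(out, reference, rtol=2e-2, atol=2e-2)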
sglang/srt/layers/flashinfer_comm_fusion.py
ADDED
@@ -0,0 +1,202 @@
+import logging
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.utils import is_flashinfer_available
+
+logger = logging.getLogger(__name__)
+
+_flashinfer_comm = None
+_workspace_manager = None
+
+if is_flashinfer_available():
+    try:
+        import flashinfer.comm as comm
+
+        _flashinfer_comm = comm
+    except ImportError:
+        logger.warning(
+            "flashinfer.comm is not available, falling back to standard "
+            "implementation"
+        )
+
+
+class FlashInferWorkspaceManager:
+    def __init__(self):
+        self.workspace_tensor = None
+        self.ipc_handles = None
+        self.world_size = None
+        self.rank = None
+        self.initialized = False
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        max_token_num: int,
+        hidden_dim: int,
+        group=None,
+        use_fp32_lamport: bool = False,
+    ):
+        """Initialize workspace"""
+        if self.initialized and self.world_size == world_size:
+            return
+
+        if _flashinfer_comm is None:
+            logger.warning(
+                "FlashInfer comm not available, skipping workspace " "initialization"
+            )
+            return
+
+        self.cleanup()
+
+        self.ipc_handles, self.workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                max_token_num,
+                hidden_dim,
+                group=group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        self.world_size = world_size
+        self.rank = rank
+        self.initialized = True
+
+        logger.info(
+            f"FlashInfer workspace initialized for rank {rank}, "
+            f"world_size {world_size}"
+        )
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.ipc_handles is not None:
+            try:
+                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
+                    self.ipc_handles, group=dist.group.WORLD
+                )
+            except Exception as e:
+                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
+            finally:
+                self.workspace_tensor = None
+                self.ipc_handles = None
+                self.initialized = False
+
+
+_workspace_manager = FlashInferWorkspaceManager()
+
+
+def ensure_workspace_initialized(
+    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+):
+    """Ensure workspace is initialized"""
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        return False
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        return False
+
+    rank = dist.get_rank()
+
+    if (
+        not _workspace_manager.initialized
+        or _workspace_manager.world_size != world_size
+    ):
+        _workspace_manager.initialize(
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            use_fp32_lamport=use_fp32_lamport,
+        )
+
+    return _workspace_manager.initialized
+
+
+def flashinfer_allreduce_add_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 1024,
+    use_oneshot: bool = True,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Use FlashInfer's fused allreduce + residual + RMS norm operation
+
+    Args:
+        input_tensor: Input tensor that needs allreduce
+        residual: Residual tensor
+        weight: RMS norm weight
+        eps: RMS norm epsilon
+        max_token_num: Maximum token number
+        use_oneshot: Whether to use oneshot mode
+        trigger_completion_at_end: Whether to trigger completion at end
+        fp32_acc: Whether to use fp32 precision
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
+    """
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        logger.debug(
+            "FlashInfer not available, falling back to standard " "implementation"
+        )
+        return None, None
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        logger.debug("Single GPU, no need for allreduce fusion")
+        return None, None
+
+    if not ensure_workspace_initialized(
+        max_token_num=max_token_num,
+        hidden_dim=input_tensor.shape[-1],
+        use_fp32_lamport=(input_tensor.dtype == torch.float32),
+    ):
+        logger.debug("FlashInfer workspace not available")
+        return None, None
+
+    token_num, hidden_dim = input_tensor.shape
+
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+
+    _flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        world_size=world_size,
+        world_rank=dist.get_rank(),
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=_workspace_manager.workspace_tensor,
+        launch_with_pdl=True,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
+        allreduce_out=None,
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=weight,
+        rms_eps=eps,
+        scale_factor=None,
+        layout_code=None,
+    )
+
+    return norm_out, residual_out
+
+
+def cleanup_flashinfer_workspace():
+    global _workspace_manager
+    if _workspace_manager is not None:
+        _workspace_manager.cleanup()
sglang/srt/layers/layernorm.py
CHANGED
@@ -52,6 +52,9 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
@@ -148,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_add_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_add_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
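Both new paths (forward_npu and forward_with_allreduce_fusion) are expected to match the existing forward_native semantics of RMSNorm. A compact standalone sketch of that contract (written here for clarity, not taken from the diff):

import torch


def rmsnorm_reference(x, weight, eps, residual=None):
    # Optional fused residual add (what npu_add_rms_norm / the flashinfer
    # fusion provide), followed by a standard RMSNorm over the last dim.
    if residual is not None:
        x = x + residual
        residual_out = x
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    out = (x.to(torch.float32) * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual_out)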
sglang/srt/layers/linear.py
CHANGED
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -30,7 +31,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +58,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +174,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +185,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)
 
 
@@ -408,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            if not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -626,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
+
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1094,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1239,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
sglang/srt/layers/logits_processor.py
CHANGED
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -436,17 +436,26 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if use_intel_amx_backend(lm_head):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -4,9 +4,8 @@ from typing import List, Optional
 import torch
 import triton
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import dispose_tensor, is_cuda
+from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -814,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -842,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                 output_tensor_scale + dest_token_index * output_tensor_scale_stride0
             )
             tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-            tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+            tl.store(
+                output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                to_copy_s,
+                mask=mask_s,
+            )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -857,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -866,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -905,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
    )
    return
 
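The new scale_ue8m0 flag changes the expected width of the per-token scale tensors. A worked example of the arithmetic, using the ceil_div now imported from sglang.srt.utils (the hidden size below is only illustrative):

from sglang.srt.utils import ceil_div

hidden_size = 7168            # illustrative, e.g. a DeepSeek-style hidden size
BLOCK_D = 128                 # quantization block size used by ep_scatter

scale_hidden_size = hidden_size // BLOCK_D                  # 56 block scales per token
packed_scale_hidden_size = ceil_div(scale_hidden_size, 4)   # 14 int32 words with scale_ue8m0=True

print(scale_hidden_size, packed_scale_hidden_size)  # 56 14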