sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(
+    a_ptr,
+    b_ptr,
+    scale_a_ptr,
+    scale_b_ptr,
+    c_ptr,
+    bias_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    ACCUMULATOR_DTYPE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_SCALE_A: tl.constexpr,
+    BLOCK_SIZE_SCALE_B: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (
+        tl.arange(0, BLOCK_SIZE_SCALE_A)
+        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+    )
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (
+        tl.arange(0, BLOCK_SIZE_SCALE_B)
+        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+    )
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input - [M, K]
+# weight - [K, N]
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+def triton_scaled_mm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+    use_heuristic=True,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](
+        input,
+        weight,
+        scale_a,
+        scale_b,
+        result,
+        bias,
+        M,
+        N,
+        K,
+        input.stride(0),
+        input.stride(1),
+        weight.stride(0),
+        weight.stride(1),
+        result.stride(0),
+        result.stride(1),
+        accumulator_dtype,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        BLOCK_SIZE_SCALE_A=block_size_sa,
+        BLOCK_SIZE_SCALE_B=block_size_sb,
+    )
+
+    return result.to(out_dtype)
sglang/srt/layers/quantization/fp8_utils.py

@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
+    triton_scaled_mm,
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )

@@ -161,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]

     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )

     if bias is not None:
@@ -586,14 +587,25 @@ def apply_fp8_linear(
         assert (
             weight_scale.numel() == weight.shape[1]
         ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-        output = fp8_scaled_mm(
-            qinput,
-            weight,
-            x_scale,
-            weight_scale,
-            out_dtype=input.dtype,
-            bias=bias,
+
+        cutlass_compatible_b = (
+            weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
         )
+        if not cutlass_compatible_b:
+            # Massage the input to be 2D
+            qinput = qinput.view(-1, qinput.shape[-1])
+            output = triton_scaled_mm(
+                qinput, weight, x_scale, weight_scale, input.dtype, bias
+            )
+        else:
+            output = fp8_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                out_dtype=input.dtype,
+                bias=bias,
+            )
         return output.view(*output_shape)

     # torch.scaled_mm supports per tensor weights + activations only
sglang/srt/layers/quantization/modelopt_quant.py

@@ -1,9 +1,8 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations

-import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter

@@ -42,11 +41,7 @@ if is_cuda():

     try:
         from flashinfer import mm_fp4 as fp4_gemm
-        from flashinfer import (
-            reorder_rows_for_gated_act_gemm,
-            shuffle_matrix_a,
-            shuffle_matrix_sf_a,
-        )
+        from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a

         enable_flashinfer_fp4_gemm = True
     except ImportError:

@@ -682,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B,
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)

@@ -883,9 +878,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B,
+            else swizzled_scale.reshape(B, M_padded, K_padded)
         )

     def prepare_static_weights_for_kernel(
sglang/srt/layers/quantization/mxfp4.py

@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size

             top_k, router_logits = topk_output
sglang/srt/layers/quantization/w4afp8.py

@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs

         # Fused gate_up_proj (column parallel)

@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,

@@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-
-
-
-
-
-
-
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,

@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
sglang/srt/layers/quantization/w8a8_int8.py

@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)

 import torch
 from torch.nn.parameter import Parameter

@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-
-
-
-
-
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out

-
-
-        return
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)

     return _rmsnorm_forward_oot

@@ -250,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig):

         if _is_npu:
             if isinstance(layer, LinearBase):
+                key = "model"
+                if "vision_model" in prefix:
+                    key = "vision_model"
+                elif "visual" in prefix:
+                    key = "visual"
+                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
                 prefix_in_quant_config = prefix
                 proj_name = prefix.split(".")[-1]
-                if proj_name in
+                if proj_name in packed_modules_mapping_subset:
                     prefix_in_quant_config = prefix.replace(
-                        proj_name,
+                        proj_name, packed_modules_mapping_subset[proj_name][0]
                     )
                 self.is_dynamic = (
                     self.quant_description[prefix_in_quant_config + ".weight"]
                     == "W8A8_DYNAMIC"
                 )
-                if self.is_layer_skipped(prefix,
+                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
                     return UnquantizedLinearMethod()
                 return (
                     NPU_W8A8DynamicLinearMethod(self)
@@ -571,8 +582,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(

@@ -583,8 +596,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +668,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)

-
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )

@@ -737,11 +762,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)


@@ -780,7 +800,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,

@@ -863,11 +882,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)