sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(
+    a_ptr,
+    b_ptr,
+    scale_a_ptr,
+    scale_b_ptr,
+    c_ptr,
+    bias_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    ACCUMULATOR_DTYPE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_SCALE_A: tl.constexpr,
+    BLOCK_SIZE_SCALE_B: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (
+        tl.arange(0, BLOCK_SIZE_SCALE_A)
+        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+    )
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (
+        tl.arange(0, BLOCK_SIZE_SCALE_B)
+        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+    )
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input - [M, K]
+# weight - [K, N]
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+def triton_scaled_mm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+    use_heuristic=True,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](
+        input,
+        weight,
+        scale_a,
+        scale_b,
+        result,
+        bias,
+        M,
+        N,
+        K,
+        input.stride(0),
+        input.stride(1),
+        weight.stride(0),
+        weight.stride(1),
+        result.stride(0),
+        result.stride(1),
+        accumulator_dtype,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        BLOCK_SIZE_SCALE_A=block_size_sa,
+        BLOCK_SIZE_SCALE_B=block_size_sb,
+    )
+
+    return result.to(out_dtype)
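Note (not part of the diff): a minimal usage sketch of the new triton_scaled_mm fallback. It assumes a CUDA device with Triton available; the import path follows the fp8_utils hunk below, and the shapes and tolerances are illustrative.

    # Usage sketch for the Triton scaled-GEMM fallback (illustrative values).
    import torch
    from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

    M, K, N = 4, 64, 48  # N, K intentionally not required to be multiples of 16
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")
    scale_a = torch.rand(M, 1, device="cuda")  # per-token (row) scale
    scale_b = torch.rand(N, 1, device="cuda")  # per-channel (column) scale

    out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    # Reference: dequantize, then matmul in fp32.
    ref = (scale_a * a.float()) @ (scale_b.t() * b.float())
    assert torch.allclose(out.float(), ref, rtol=1e-2, atol=1e-2)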
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
+    triton_scaled_mm,
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )
@@ -161,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None:
@@ -586,14 +587,25 @@ def apply_fp8_linear(
         assert (
             weight_scale.numel() == weight.shape[1]
         ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-        output = fp8_scaled_mm(
-            qinput,
-            weight,
-            x_scale,
-            weight_scale,
-            out_dtype=input.dtype,
-            bias=bias,
+
+        cutlass_compatible_b = (
+            weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
         )
+        if not cutlass_compatible_b:
+            # Massage the input to be 2D
+            qinput = qinput.view(-1, qinput.shape[-1])
+            output = triton_scaled_mm(
+                qinput, weight, x_scale, weight_scale, input.dtype, bias
+            )
+        else:
+            output = fp8_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                out_dtype=input.dtype,
+                bias=bias,
+            )
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
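Note (not part of the diff): the dispatch added above keeps the CUTLASS fp8_scaled_mm path only when both weight dimensions are multiples of 16 and otherwise routes through the new triton_scaled_mm. A small sketch of that condition; the helper name and shapes are hypothetical.

    # Hypothetical helper mirroring the cutlass_compatible_b check above.
    def uses_triton_fallback(weight_shape: tuple) -> bool:
        dim0, dim1 = weight_shape
        return not (dim0 % 16 == 0 and dim1 % 16 == 0)

    assert uses_triton_fallback((4096, 4096)) is False  # stays on fp8_scaled_mm
    assert uses_triton_fallback((1000, 4096)) is True   # routed to triton_scaled_mm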
@@ -1,9 +1,8 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations
 
-import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -42,11 +41,7 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
-    from flashinfer import (
-        reorder_rows_for_gated_act_gemm,
-        shuffle_matrix_a,
-        shuffle_matrix_sf_a,
-    )
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -682,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(M, K)
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B, M, K)
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -883,9 +878,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(M, K)
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B, M, K)
+            else swizzled_scale.reshape(B, M_padded, K_padded)
         )
 
     def prepare_static_weights_for_kernel(
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size
 
             top_k, router_logits = topk_output
 
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs
 
         # Fused gate_up_proj (column parallel)
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
@@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-        hidden_states: torch.Tensor,
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        if layer.expert_map is not None:
-            "Translate info from expert_map to topk_ids"
-            local_topk_ids = torch.where(
-                layer.expert_map[topk_ids] != layer.num_experts,
-                layer.expert_map[topk_ids],
-                layer.num_experts,
-            )
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-            hidden_states,
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
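Note (not part of the diff): the hunks above drop the expert_map translation and instead map invalid top-k slots (topk_ids == -1) to layer.num_experts before calling cutlass_w4a8_moe. A tiny sketch of that remapping on dummy values; the numbers are illustrative.

    # Illustrative only: -1 entries become num_experts, valid ids pass through.
    import torch

    num_experts = 8
    topk_ids = torch.tensor([[3, 7, -1], [0, -1, 5]])
    local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
    print(local_topk_ids)  # tensor([[3, 7, 8], [0, 8, 5]])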
@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(original_dtype)
-
-        x = (
-            torch_npu.npu_rms_norm(
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
 
-        if residual is None:
-            return x.to(original_dtype)
-        return x.to(original_dtype), residual
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
 
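Note (not part of the diff): the rewrite above replaces the explicit fp32 residual add plus npu_rms_norm with the fused torch_npu.npu_add_rms_norm call. As a reference for the semantics that fused op is assumed to provide (an assumption, not taken from the diff), a plain-PyTorch sketch:

    # Assumed reference semantics for fused add-RMSNorm (sketch only).
    import torch

    def add_rms_norm_ref(residual, x, weight, eps):
        # residual_out is the pre-normalization sum; out is RMSNorm(residual_out) * weight
        residual_out = x.float() + residual.float()
        variance = residual_out.pow(2).mean(-1, keepdim=True)
        out = residual_out * torch.rsqrt(variance + eps) * weight.float()
        return out.to(x.dtype), residual_out.to(x.dtype)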
@@ -250,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig):
 
         if _is_npu:
             if isinstance(layer, LinearBase):
+                key = "model"
+                if "vision_model" in prefix:
+                    key = "vision_model"
+                elif "visual" in prefix:
+                    key = "visual"
+                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
                 prefix_in_quant_config = prefix
                 proj_name = prefix.split(".")[-1]
-                if proj_name in self.packed_modules_mapping:
+                if proj_name in packed_modules_mapping_subset:
                     prefix_in_quant_config = prefix.replace(
-                        proj_name, self.packed_modules_mapping[proj_name][0]
+                        proj_name, packed_modules_mapping_subset[proj_name][0]
                     )
                 self.is_dynamic = (
                     self.quant_description[prefix_in_quant_config + ".weight"]
                     == "W8A8_DYNAMIC"
                 )
-                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
+                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
                     return UnquantizedLinearMethod()
                 return (
                     NPU_W8A8DynamicLinearMethod(self)
@@ -571,8 +582,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +596,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +668,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
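Note (not part of the diff): the comments added above capture the intent of the rank check. In a row-parallel linear, each TP rank produces a partial output that is summed across ranks, so fusing the bias into every rank's GEMM would add it tp_size times. A small numeric sketch; tp_size and the shapes are illustrative assumptions.

    # Illustrative only: why the fused bias must be applied on a single rank.
    import torch

    tp_size, M, N = 4, 2, 3
    partials = [torch.randn(M, N) for _ in range(tp_size)]  # per-rank GEMM outputs
    bias = torch.randn(N)

    wrong = sum(p + bias for p in partials)  # bias counted tp_size times
    right = sum(partials) + bias             # bias counted once
    assert torch.allclose(wrong - right, (tp_size - 1) * bias.expand(M, N), atol=1e-6)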
@@ -737,11 +762,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 
@@ -780,7 +800,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +882,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 