sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
 )
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
 )
@@ -71,6 +74,11 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
+_is_fp8_fnuz = is_fp8_fnuz()
+
+use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
 if _is_hip:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -97,10 +105,7 @@ class Fp8Config(QuantizationConfig):
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
-            logger.warning(
-                "Detected fp8 checkpoint. Please note that the "
-                "format is experimental and subject to change."
-            )
+            log_info_on_rank0(logger, "Detected fp8 checkpoint.")
         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
@@ -306,25 +311,21 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
+
                 layer.input_scale = None
             else:
-                layer.weight = torch.nn.Parameter(
-                    layer.weight.data, requires_grad=False
-                )
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    layer.weight_scale_inv.data, requires_grad=False
-                )
+                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = torch.nn.Parameter(
+                weight_scale, requires_grad=False
+            )
             return
 
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
@@ -368,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -482,11 +483,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = (
-                torch.uint32
-                if get_bool_env_var("SGLANG_INT4_WEIGHT")
-                else torch.float8_e4m3fn
-            )
+            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
             tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -511,7 +508,7 @@ class Fp8MoEMethod:
                     )
 
         # WEIGHTS
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -583,9 +580,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
-        if (
-            _is_hip
-        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
+        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +607,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,14 +639,14 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             self.process_weights_hip_int4(layer)
             return
 
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -675,20 +670,19 @@ class Fp8MoEMethod:
                 )
                 layer.w2_input_scale = None
 
-                if get_bool_env_var("SGLANG_AITER_MOE"):
-                    # Pre-shuffle weights
-                    layer.w13_weight.data = shuffle_weight(
-                        layer.w13_weight.contiguous(), (16, 16)
-                    )
-                    layer.w2_weight.data = shuffle_weight(
-                        layer.w2_weight.contiguous(), (16, 16)
-                    )
+            if _is_hip and use_aiter_moe:
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -742,7 +736,7 @@ class Fp8MoEMethod:
            )
 
        # If ROCm, normalize the weights and scales to e4m3fnuz
-        if _is_hip:
+        if _is_fp8_fnuz:
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
@@ -798,7 +792,7 @@ class Fp8MoEMethod:
         return
 
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
+        # TODO: and use_aiter_moe: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -845,7 +839,7 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )
 
-        if get_bool_env_var("SGLANG_AITER_MOE"):
+        if use_aiter_moe:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -856,7 +850,7 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
+            # ROCm (use_aiter_moe): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
         elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -908,59 +902,16 @@ class Fp8MoEMethod:
         )
 
         if _is_hip:
-            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
-                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
-
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                assert not no_combine, f"{no_combine=} is not supported."
-                if self.block_quant:
-                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                    assert (
-                        activation == "silu"
-                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
-                    return asm_moe(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        layer.w13_weight_scale_inv,
-                        layer.w2_weight_scale_inv,
-                        block_shape=tuple(self.quant_config.weight_block_size),
-                        expert_mask=None,
-                    )
-                else:
-                    return ck_moe_2stages(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        QuantType.per_Token,
-                        layer.w13_weight_scale1,
-                        layer.w2_weight_scale1,
-                        activation=(
-                            ActivationType.Silu
-                            if activation == "silu"
-                            else ActivationType.Gelu
-                        ),
-                    )
+            ret = self.maybe_apply_hip_fused_experts(
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation,
+                no_combine,
+            )
+            if ret is not None:
+                return ret
 
         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -987,6 +938,68 @@ class Fp8MoEMethod:
             no_combine=no_combine,
         )
 
+    def maybe_apply_hip_fused_experts(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        no_combine: bool = False,
+    ) -> Optional[torch.Tensor]:
+        if use_hip_int4:
+            # TODO: add triton kernel and add check use_aiter_moe
+            assert not no_combine, f"{no_combine=} is not supported."
+            return ck_moe_2stages(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                QuantType.per_Token,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
+            )
+
+        if use_aiter_moe:
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return ck_moe_2stages(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    QuantType.per_Token,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+        return None
+
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -29,17 +30,12 @@ from sglang.srt.utils import (
     get_device_name,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
     supports_custom_op,
 )
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
-_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-if _is_hip:
-    fp8_max = 224.0
-else:
-    fp8_max = torch.finfo(_fp8_type).max
-fp8_min = -fp8_max
 
 if _is_cuda:
     from sgl_kernel import (
@@ -54,6 +50,24 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+
+@lru_cache()
+def is_fp8_fnuz() -> bool:
+    if _is_hip:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+    return False
+
+
+if is_fp8_fnuz():
+    fp8_dtype = torch.float8_e4m3fnuz
+    fp8_max = 224.0
+else:
+    fp8_dtype = torch.float8_e4m3fn
+    fp8_max = torch.finfo(fp8_dtype).max
+fp8_min = -fp8_max
+
+
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
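
For context on the hunk above: fp8_kernel.py now exposes `is_fp8_fnuz()`, `fp8_dtype`, `fp8_min`, and `fp8_max` at module level, so callers (including fp8.py) no longer re-derive the platform dtype. A minimal illustration of how such constants can drive a per-tensor quantization fallback in plain PyTorch, assuming torch >= 2.1; this is an illustrative sketch, not sglang's Triton kernel path:

import torch

fp8_dtype = torch.float8_e4m3fn  # stand-in for fp8_kernel.fp8_dtype on non-fnuz devices
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max

def naive_per_tensor_quant_fp8(x: torch.Tensor):
    # One scale for the whole tensor: the largest |value| maps to fp8_max.
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max
    x_q = (x / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
    return x_q, scale

x = torch.randn(4, 8)
x_q, s = naive_per_tensor_quant_fp8(x)
print(x_q.dtype, s.item())  # torch.float8_e4m3fn and the per-tensor scale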
@@ -198,7 +212,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -272,7 +286,7 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if column_major_scales:
         if scale_tma_aligned:
             # aligned to 4 * sizeof(float)
@@ -294,15 +308,15 @@ def sglang_per_token_group_quant_fp8(
             device=x.device,
             dtype=torch.float32,
         )
-
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if x.shape[0] > 0:
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
     return x_q, x_s
 
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = _fp8_type,
+    dtype: torch.dtype = fp8_dtype,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 
@@ -384,7 +398,7 @@ def static_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -685,9 +699,9 @@ def get_w8a8_block_fp8_configs(
     )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info(
-                "Using configuration from %s for W8A8 Block FP8 kernel.",
-                config_file_path,
+            log_info_on_rank0(
+                logger,
+                f"Using configuration from {config_file_path} for W8A8 Block FP8 kernel.",
             )
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
@@ -704,6 +718,28 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
+def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+    return _w8a8_block_fp8_matmul
+
+
+if _is_hip:
+
+    def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+        # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+        # Empirical testing shows the sweet spot lies when it's less than the # of
+        # compute units available on the device.
+        num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
+            N, META["BLOCK_SIZE_N"]
+        )
+        return num_workgroups <= get_device_core_count()
+
+    def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+        if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+            return _w8a8_block_fp8_matmul_unrolledx4
+        else:
+            return _w8a8_block_fp8_matmul
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -744,35 +780,6 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
 
-    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": block_size[0],
-            "BLOCK_SIZE_K": block_size[1],
-            "GROUP_SIZE_M": 32,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
-
-    def grid(META):
-        return (
-            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-        )
-
-    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
-    # Empirical testing shows the sweet spot lies when it's less than the # of
-    # compute units available on the device.
-    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
-        N, config["BLOCK_SIZE_N"]
-    )
-
     # deepgemm only support bf16
     if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
@@ -780,11 +787,30 @@ def w8a8_block_fp8_matmul(
         else:
             deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
-        kernel = (
-            _w8a8_block_fp8_matmul_unrolledx4
-            if (_is_hip == True and num_workgroups <= get_device_core_count())
-            else _w8a8_block_fp8_matmul
-        )
+        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Default config
+            # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 32,
+                "num_warps": 4,
+                "num_stages": 3,
+            }
+
+        def grid(META):
+            return (
+                triton.cdiv(M, META["BLOCK_SIZE_M"])
+                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+            )
+
+        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
 
         kernel[grid](
             A,
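
The two hunks above move the config lookup and grid computation into the non-DeepGEMM branch and route kernel choice through `select_w8a8_block_fp8_matmul_kernel`, which on ROCm prefers the unrolledx4 Triton kernel whenever the launch grid is no larger than the device's compute-unit count. A dependency-free sketch of that selection rule, with placeholder kernels and an assumed core count rather than sglang's real `get_device_core_count()` value:

import math

DEVICE_CORE_COUNT = 304  # assumed CU count; sglang queries get_device_core_count()

def kernel_generic(num_workgroups):     # placeholder for _w8a8_block_fp8_matmul
    return f"generic kernel, {num_workgroups} workgroups"

def kernel_unrolledx4(num_workgroups):  # placeholder for _w8a8_block_fp8_matmul_unrolledx4
    return f"unrolledx4 kernel, {num_workgroups} workgroups"

def select_kernel(M, N, meta):
    # Same rule as the diff: count workgroups from the block sizes, then
    # pick the unrolled variant only for small grids.
    num_workgroups = math.ceil(M / meta["BLOCK_SIZE_M"]) * math.ceil(
        N / meta["BLOCK_SIZE_N"]
    )
    if num_workgroups <= DEVICE_CORE_COUNT:
        return kernel_unrolledx4, num_workgroups
    return kernel_generic, num_workgroups

meta = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}
for M, N in [(128, 512), (65536, 8192)]:
    kernel, n_wg = select_kernel(M, N, meta)
    print(kernel(n_wg))  # small grid -> unrolledx4, large grid -> generic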
@@ -879,7 +905,7 @@ def per_tensor_quant_mla_fp8(
         and x_s_out.device == x.device
     )
 
-    x_q = x.new_empty(x.size(), dtype=_fp8_type)
+    x_q = x.new_empty(x.size(), dtype=fp8_dtype)
 
     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -961,11 +987,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
     tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
 
 
-def per_tensor_quant_mla_deep_gemm_masked_fp8(
+def per_token_group_quant_mla_deep_gemm_masked_fp8(
     x: torch.Tensor,
     group_size: int = 128,
     eps: float = 1e-12,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with per-token-group-quantization
@@ -973,12 +999,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
     b, m, k = x.shape
     aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
     num_tiles_k = k // group_size
@@ -1043,10 +1063,9 @@ def scaled_fp8_quant(
     """
     assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
     shape = input.shape
-    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
 
     if scale is None:
         # Dynamic scaling