sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
-    apply_fp8_linear,
     convert_to_channelwise,
-    cutlass_fp8_supported,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.parameter import (
+    BlockQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
-    BlockQuantScaleParameter,
+    apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
+    cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import (
@@ -49,9 +54,9 @@ from sglang.srt.utils import (
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
-if is_hip_:
+if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
 
@@ -170,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip_:
+        if _is_hip:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -282,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
                 )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # per-tensor quantization
+                qweight, weight_scale = input_to_float8(layer.weight)
 
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
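Note: the new non-serialized path above picks per-channel FP8 quantization whenever the cutlass sgl-kernel or Marlin kernel is usable (one scale per output channel, with the group size equal to the input dimension) and otherwise falls back to a single per-tensor scale. A minimal standalone sketch of the two schemes, assuming the float8_e4m3fn maximum of 448.0 and hypothetical helper names (an illustration, not the sglang kernels):

    import torch

    FP8_MAX = 448.0  # assumed float8_e4m3fn dynamic range

    def quant_per_tensor(w: torch.Tensor):
        # single scale for the whole weight matrix
        scale = w.abs().max().clamp(min=1e-12) / FP8_MAX
        return (w / scale).to(torch.float8_e4m3fn), scale

    def quant_per_channel(w: torch.Tensor):
        # one scale per output row (group size == w.shape[-1]),
        # the layout cutlass/marlin-style kernels expect
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        return (w / scale).to(torch.float8_e4m3fn), scale

    w = torch.randn(128, 256, dtype=torch.bfloat16)
    _, s_tensor = quant_per_tensor(w)
    _, s_channel = quant_per_channel(w)
    print(s_tensor.shape, s_channel.shape)  # torch.Size([]) vs torch.Size([128, 1])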
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
+
+            # cutlass sgl-kernel and marlin only support per-channel scale
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )
-
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale
-
                 # If ROCm, normalize the weights and scales to e4m3fnuz
-                if is_hip_:
+                if _is_hip:
                     weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=weight_scale,
@@ -460,7 +461,11 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
+            params_dtype = (
+                torch.int32
+                if get_bool_env_var("USE_INT4_WEIGHT")
+                else torch.float8_e4m3fn
+            )
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -485,21 +490,40 @@ class Fp8MoEMethod:
                     )
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            # INT4 MoE weight - INT32 packed
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size,
+                    hidden_size // 8,
+                    dtype=params_dtype,
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+        else:
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
@@ -538,7 +562,9 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if (
+            _is_hip
+        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -565,6 +591,13 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+            )
+            set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale1, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
             if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -590,14 +623,14 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            self.process_weights_hip_int4(layer)
+            return
 
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -630,10 +663,11 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             return
+
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -655,33 +689,8 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+            if _is_hip:
+                self.process_weights_hip_scale_padding(layer)
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -712,7 +721,7 @@ class Fp8MoEMethod:
                 )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -762,35 +771,85 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+            if _is_hip:
+                self.process_weights_hip_scale_padding(layer)
             return
 
+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
+        if get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
+        elif get_bool_env_var("MOE_PADDING"):
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            layer.w13_weight = torch.nn.Parameter(
+                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
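Note: process_weights_hip_int4 folds the two per-shard w13 FP8 scales into one per-expert scale by keeping the maximum and compensating the per-column INT4 scales, then multiplies the per-column scales by the kept FP8 scale, since asm_moe applies (weight_scale1 * weight_scale) as post-GEMM scaling. A toy numeric check of that invariant with hypothetical values (plain Python, independent of the actual kernel):

    # one expert, two shards (gate/up), shard_size = 2 columns each
    w13_weight_scale = [0.5, 0.25]            # per-shard FP8 scales
    w13_weight_scale1 = [2.0, 2.0, 3.0, 3.0]  # per-column INT4 scales

    # effective post-GEMM scaling before the rewrite: scale1 * per-shard scale
    before = [s1 * w13_weight_scale[i // 2] for i, s1 in enumerate(w13_weight_scale1)]

    # fold: keep max(w13_weight_scale), rescale the columns of the smaller shard
    max_scale = max(w13_weight_scale)
    adjusted = list(w13_weight_scale1)
    for shard_id, shard_scale in enumerate(w13_weight_scale):
        if shard_scale != max_scale:
            for col in range(shard_id * 2, shard_id * 2 + 2):
                adjusted[col] *= shard_scale / max_scale

    after = [s1 * max_scale for s1 in adjusted]
    assert before == after  # post-GEMM scaling is preserved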
@@ -823,8 +882,24 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            assert not no_combine, f"{no_combine=} is not supported."
+            return asm_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=activation,
+            )
+        if _is_hip and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert (
+                activation == "silu"
+            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
                 return asm_moe(
@@ -835,10 +910,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale_inv,
                     layer.w2_weight_scale_inv,
-                    None,
-                    None,
-                    False,
-                    None,
                     block_shape=tuple(self.quant_config.weight_block_size),
                     expert_mask=None,
                 )
@@ -851,9 +922,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
-                    None,
-                    None,
-                    False,
                 )
         else:
             # Expert fusion with FP8 quantization