sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/marlin_utils_fp8.py (new file)
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    USE_FP32_REDUCE_DEFAULT,
+    marlin_make_workspace,
+    marlin_permute_bias,
+    marlin_permute_scales,
+    should_use_atomic_add_reduce,
+)
+from sglang.srt.layers.quantization.utils import get_scalar_types
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+
+def fp8_fused_exponent_bias_into_scales(scales):
+    fp8_exponent = 4
+    if scales.dtype == torch.half:
+        target_exponent = 5
+    elif scales.dtype == torch.bfloat16:
+        target_exponent = 8
+    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
+    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
+    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
+    s = torch.ones_like(scales) * 2
+    s = s**exponent_bias
+    return scales * s
+
+
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
+    )
+
+    output = gptq_marlin_gemm(
+        a=reshaped_x,
+        c=None,
+        b_q_weight=weight,
+        b_bias=bias,
+        b_scales=weight_scale,
+        global_scale=None,
+        b_zeros=None,
+        g_idx=None,
+        perm=None,
+        workspace=workspace,
+        b_q_type=scalar_types.float8_e4m3fn,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+    )
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    if size_k_first:
+        assert layer.weight.shape == (part_size_k, part_size_n)
+    else:
+        assert layer.weight.shape == (part_size_n, part_size_k)
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
+    if not size_k_first:
+        qweight = qweight.T.contiguous()
+
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Permute scales
+    if "weight_scale" in dir(layer):
+        scales = layer.weight_scale.to(layer.orig_dtype)
+    elif "weight_scale_inv" in dir(layer):
+        scales = layer.weight_scale_inv.to(layer.orig_dtype)
+        del layer.weight_scale_inv
+
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    # the marlin kernel only supports channel-wise and group-wise quantization,
+    # so we need to convert the scales
+    if weight_block_size is None:
+        if scales.nelement() == 1:
+            # tensor-wise quantization -> channel-wise quantization
+            # (1, 1) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
+            assert part_size_n % scales.nelement() == 0
+            s_size = scales.nelement()
+            # tensor-wise quantization (for gate-up proj)
+            # -> channel-wise quantization
+            # (1, s_size) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, s_size)
+            scales = scales.repeat_interleave(part_size_n // s_size, 1)
+        else:
+            # channel-wise quantization
+            # (1, size_n)
+            scales = scales.view(1, part_size_n)
+    else:
+        # block-wise quantization -> group-wise quantization
+        # (size_k // block_size[1], ceil(size_n / block_size[0]))
+        # =>(repeat)=> (size_k // block_size[1], size_n)
+        if not size_k_first:
+            scales = scales.T.contiguous()
+        block_n = weight_block_size[0]
+        scales = scales.repeat_interleave(block_n, 1)
+        # size_n may not be divisible by block_size[0]
+        scales = scales[:, :part_size_n]
+
+    marlin_scales = marlin_permute_scales(
+        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    if hasattr(layer, "bias") and layer.bias is not None:
+        assert layer.bias.shape == (part_size_n,)
+        bias = marlin_permute_bias(layer.bias)
+        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+
+
+def prepare_moe_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    e = layer.num_experts
+    k = layer.hidden_size
+    n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    # WORKSPACE
+    device = layer.w13_weight.device
+    layer.workspace = marlin_make_workspace(device, 4)
+    perm = torch.empty(0, dtype=torch.int, device=device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    for name in ["w13_weight", "w2_weight"]:
+        weight = getattr(layer, name)
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        if size_k_first:
+            assert weight.shape == (e, size_k, size_n)
+        else:
+            assert weight.shape == (e, size_n, size_k)
+
+        for i in range(e):
+            qweight = pack_fp8_to_int32(weight[i], size_k_first)
+            if not size_k_first:
+                qweight = qweight.T.contiguous()
+
+            marlin_qweight = gptq_marlin_repack(
+                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
+            )
+            tensor_list.append(marlin_qweight)
+
+        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        setattr(layer, name, weight)
+
+    # WEIGHT SCALES
+    # Permute scales
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    for name in ["w13", "w2"]:
+        if name + "_weight_scale" in dir(layer):
+            new_name = name + "_weight_scale"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+        elif name + "_weight_scale_inv" in dir(layer):
+            new_name = name + "_weight_scale_inv"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        # the marlin kernel only supports channel-wise and group-wise quantization,
+        # so we need to convert the scales
+        if weight_block_size is None:
+            if scales.nelement() == e:
+                # tensor-wise quantization -> channel-wise quantization
+                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
+            elif scales.nelement() > e and scales.nelement() != e * size_n:
+                assert (e * size_n) % scales.nelement() == 0
+                s_size = scales.nelement() // e
+                # tensor-wise quantization (for gate-up proj)
+                # -> channel-wise quantization
+                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, s_size)
+                scales = scales.repeat_interleave(size_n // s_size, 2)
+            else:
+                # channel-wise quantization
+                # (e, 1, size_n)
+                scales = scales.view(e, 1, size_n)
+        else:
+            # block-wise quantization -> group-wise quantization
+            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
+            # =>(repeat)=> (e, size_k // block_size[1], size_n)
+            if not size_k_first:
+                scales = scales.permute(0, 2, 1)
+            block_n = weight_block_size[0]
+            scales = scales.repeat_interleave(block_n, 2)
+            # size_n may not be divisible by block_size[0]
+            scales = scales[..., :size_n].contiguous()
+
+        for i in range(e):
+            marlin_scales = marlin_permute_scales(
+                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
+            )
+            tensor_list.append(marlin_scales)
+
+        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
+        scales = torch.nn.Parameter(scales, requires_grad=False)
+
+        setattr(layer, name + "_weight_scale", scales)
+
+    # BIAS
+    # Permute bias
+    for name in ["w13_bias", "w2_bias"]:
+        if not hasattr(layer, name):
+            continue
+        bias = getattr(layer, name).to(layer.orig_dtype)
+
+        tensor_list = []
+        for i in range(e):
+            expert_bias = bias[i]
+
+            tensor_list.append(marlin_permute_bias(expert_bias))
+
+        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        bias = torch.nn.Parameter(bias, requires_grad=False)
+        setattr(layer, name, bias)
+
+
+def pack_fp8_to_int32(
+    fp8_tensor: torch.Tensor, size_k_first: bool = True
+) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.ndim == 2
+
+    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
+    fp8_tensor = fp8_tensor.contiguous()
+    # fp8_tensor is contiguous and has shape (N, K) now
+    # with `.view(torch.int32)`, it becomes (N, K // 4)
+    int32_tensor = fp8_tensor.view(torch.int32)
+    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
+
+
+def marlin_quant_fp8_torch(weight, group_size):
+    size_n, size_k = weight.shape
+    device = weight.device
+
+    if group_size != -1:
+        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(group_size, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+    else:
+        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(size_k, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+
+    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=packed_weight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    marlin_scales = marlin_permute_scales(
+        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
+    )
+
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+
+    return weight_ref.T, marlin_qweight, marlin_scales
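
A quick check of `fp8_fused_exponent_bias_into_scales`: float16 has a 5-bit exponent and `float8_e4m3fn` a 4-bit one, so the bias folded into the scales is `2**(5-1) - 2**(4-1) = 8`, i.e. every scale is multiplied by `2**8 = 256`; for bfloat16 (8-bit exponent) the factor is `2**120`. The snippet below is not part of the diff; it simply replays that arithmetic with plain torch.

```python
import torch

# Replay the helper's arithmetic for fp16 scales: s = 2, raised to the fused bias.
scales = torch.full((1, 4), 0.5, dtype=torch.half)
exponent_bias = 2 ** (5 - 1) - 2 ** (4 - 1)           # = 8 for fp16 targets
fused = scales * (torch.ones_like(scales) * 2) ** exponent_bias
assert torch.equal(fused, scales * 256)               # i.e. scales * 2**8

# For bf16 targets the bias is 2 ** (8 - 1) - 2 ** (4 - 1) = 120, a factor of
# 2**120 (~1.3e36), which still fits inside bfloat16's dynamic range.
```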
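`pack_fp8_to_int32` relies on `Tensor.view(torch.int32)` to reinterpret four 1-byte FP8 values as one int32. A minimal shape check, assuming the module is importable as `sglang.srt.layers.quantization.marlin_utils_fp8` (the path from entry 84 above) and that your PyTorch build supports `float8_e4m3fn` on CPU:

```python
import torch

from sglang.srt.layers.quantization.marlin_utils_fp8 import pack_fp8_to_int32

k, n = 128, 64
w = torch.zeros(k, n).to(torch.float8_e4m3fn)   # (K, N) layout, size_k first

packed = pack_fp8_to_int32(w, size_k_first=True)
# Four FP8 bytes are packed into one int32 along K: (K, N) -> (K // 4, N)
assert packed.dtype == torch.int32
assert packed.shape == (k // 4, n)
```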
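Putting the pieces together, here is a minimal sketch (not from the diff) of how the new helpers compose on a CUDA device with `sgl_kernel` installed: quantize a half-precision weight with `marlin_quant_fp8_torch`, then run the weight-only FP8 GEMM through `apply_fp8_marlin_linear`. The module path and the use of `marlin_make_workspace` are taken from the code above; the tensor shapes are illustrative.

```python
import torch

from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    marlin_quant_fp8_torch,
)

size_n, size_k = 4096, 4096                      # output / input features
weight = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")

# Quantize to FP8 with group-wise scales and repack into the Marlin layout.
# weight_ref is the dequantized (size_k, size_n) reference weight.
weight_ref, marlin_qweight, marlin_scales = marlin_quant_fp8_torch(weight, group_size=128)

x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
out = apply_fp8_marlin_linear(
    input=x,
    weight=marlin_qweight,
    weight_scale=marlin_scales,
    workspace=marlin_make_workspace(x.device),
    size_n=size_n,
    size_k=size_k,
    bias=None,
)
assert out.shape == (8, size_n)
# out should closely track x @ weight_ref.
```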