sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,503 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.utils import get_device_capability, set_weight_attrs
+
+logger = logging.getLogger(__name__)
+
+
+class MoeWNA16Config(QuantizationConfig):
+    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
+
+    def __init__(
+        self,
+        linear_quant_method: str,
+        weight_bits: int,
+        group_size: int,
+        has_zp: bool,
+        lm_head_quantized: bool,
+        modules_to_not_convert: Optional[List[str]],
+        full_config: Dict[str, Any],
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.has_zp = has_zp
+        self.bit8_pack_factor = 8 // self.weight_bits
+        self.lm_head_quantized = lm_head_quantized
+        self.linear_quant_method = linear_quant_method
+        self.full_config = full_config
+        self.use_marlin = False
+        # Avoid circular import
+
+        if self.linear_quant_method == "gptq":
+            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
+        elif self.linear_quant_method == "awq":
+            capability_tuple = get_device_capability()
+            device_capability = (
+                -1
+                if capability_tuple is None
+                else capability_tuple[0] * 10 + capability_tuple[1]
+            )
+            awq_min_capability = AWQConfig.get_min_capability()
+            if device_capability < awq_min_capability:
+                raise ValueError(
+                    "The quantization method moe_wna16 + awq is not supported "
+                    "for the current GPU. "
+                    f"Minimum capability: {awq_min_capability}. "
+                    f"Current capability: {device_capability}."
+                )
+        else:
+            raise ValueError("moe_wna16 only support gptq and awq.")
+
+        if modules_to_not_convert is None:
+            self.modules_to_not_convert = []
+        else:
+            self.modules_to_not_convert = modules_to_not_convert
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "moe_wna16"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        if quant_method == "gptq":
+            has_zp = not cls.get_from_keys(config, ["sym"])
+            modules_to_not_convert = []
+        elif quant_method == "awq":
+            has_zp = cls.get_from_keys(config, ["zero_point"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None
+            )
+        else:
+            raise ValueError("moe_wna16 only support gptq and awq.")
+
+        return cls(
+            quant_method,
+            weight_bits,
+            group_size,
+            has_zp,
+            lm_head_quantized,
+            modules_to_not_convert,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
+        if can_convert and user_quant == "moe_wna16":
+            return cls.get_name()
+        return None
+
+    @classmethod
+    def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
+        # Extract data from quant config.
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        desc_act = quant_config.get("desc_act")
+
+        capability_tuple = get_device_capability()
+        device_capability = (
+            -1
+            if capability_tuple is None
+            else capability_tuple[0] * 10 + capability_tuple[1]
+        )
+        # Avoid circular import
+        awq_min_capability = AWQConfig.get_min_capability()
+
+        gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
+        awq_compatible = (
+            quant_method == "awq"
+            and num_bits == 4
+            and device_capability >= awq_min_capability
+        )
+
+        return gptq_compatible or awq_compatible
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, LinearBase):
+
+            if self.linear_quant_method == "gptq":
+                if self.use_marlin:
+                    return GPTQMarlinConfig.from_config(
+                        self.full_config
+                    ).get_quant_method(layer, prefix)
+                else:
+                    return GPTQConfig.from_config(self.full_config).get_quant_method(
+                        layer, prefix
+                    )
+            elif self.linear_quant_method == "awq":
+                return AWQConfig.from_config(self.full_config).get_quant_method(
+                    layer, prefix
+                )
+            else:
+                raise ValueError("moe_wna16 only support gptq and awq.")
+        elif isinstance(layer, FusedMoE):
+            return MoeWNA16Method(self)
+        return None
+
+
+def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class MoeWNA16Method:
+    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.
+
+    Args:
+        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: MoeWNA16Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        layer.quant_config = self.quant_config
+        bit8_pack_factor = self.quant_config.bit8_pack_factor
+        group_size = self.quant_config.group_size
+        group_size_div_factor = 1
+
+        # make intermediate_size and hidden_size diviable by group_size
+        # we reduce the group size to ensure that
+        # and we would repeat the loaded_weight later
+        while intermediate_size_per_partition % group_size or hidden_size % group_size:
+            group_size = group_size // 2
+            group_size_div_factor *= 2
+            assert group_size >= 32
+        layer.group_size = group_size
+        layer.group_size_div_factor = group_size_div_factor
+
+        strategy = FusedMoeWeightScaleSupported.GROUP.value
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
+
+        assert "weight_loader" in extra_weight_attrs
+        weight_loader = extra_weight_attrs["weight_loader"]
+        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
+        extra_weight_attrs["weight_loader"] = wrapped_weight_loader
+
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // bit8_pack_factor,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // bit8_pack_factor,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        w13_scales = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // group_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // group_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        if self.quant_config.has_zp:
+            w13_qzeros = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    2 * intermediate_size_per_partition // bit8_pack_factor,
+                    hidden_size // group_size,
+                    dtype=torch.uint8,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_qzeros", w13_qzeros)
+            set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+            w2_qzeros = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    hidden_size // bit8_pack_factor,
+                    intermediate_size_per_partition // group_size,
+                    dtype=torch.uint8,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_qzeros", w2_qzeros)
+            set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+        if self.quant_config.linear_quant_method == "gptq":
+            # some param are unused, but we need to init them in order to
+            # load weights
+            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
+            if not self.quant_config.has_zp:
+                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
+            for key in invalid_param_keys:
+                param = torch.nn.Parameter(
+                    torch.empty((0,), dtype=torch.int32), requires_grad=False
+                )
+                layer.register_parameter(key, param)
+                set_weight_attrs(param, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        weight_bits = self.quant_config.weight_bits
+        has_zp = self.quant_config.has_zp
+
+        return fused_experts(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            w1_zp=layer.w13_qzeros if has_zp else None,
+            w2_zp=layer.w2_qzeros if has_zp else None,
+            block_shape=[0, layer.group_size],
+            no_combine=no_combine,
+        )
+
+    @staticmethod
+    def get_weight_loader(layer, weight_loader):
+
+        def convert_awq_tensor(tensor, tensor_type):
+            # convert awq qweight/qzeros to a standard format (assume int4)
+            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
+            # qzeros: (k // group_size, n // pack_factor_bit32) ->
+            #     (n // pack_factor_bit8, k // group_size)
+            # pack_factor_bit32 = 32 // weight_bits
+            # pack_factor_bit8 = 8 // weight_bits
+
+            # 0. suppose origin shape (a, b), dtype int32
+            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
+            size0 = tensor.size(0)
+            tensor = tensor.view(torch.uint8)
+
+            # 2. unpack to uint4 (only when weight_bits == 4)
+            #    shape (a, 4 * b) -> (a, 4 * b, 2)
+            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+            tensor = (tensor[:, :, None] >> shifter) & 0xF
+
+            # 3. change order, see
+            #    https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
+            #    shape -> (a, 4 * b * pack_factor_bit8)
+            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
+            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
+            tensor = tensor.view(size0, -1)
+
+            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
+            tensor = tensor.T.contiguous()
+
+            # 5. repack (only when weight_bits == 4)
+            #    qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
+            #    qzeros shape -> (4 * b, a)
+
+            if tensor_type == "qweight":
+                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
+            elif tensor_type == "qzeros":
+                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
+            return tensor
+
+        def convert_gptq_int4_qzeros(tensor):
+            tensor = tensor.view(torch.uint8)
+            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+            tensor = (tensor[:, :, None] >> shifter) & 0xF
+            tensor = tensor + 1
+            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
+            return tensor
+
+        def moe_wna16_weight_loader(
+            param: torch.nn.Parameter,
+            loaded_weight: torch.Tensor,
+            weight_name: str,
+            shard_id: str,
+            expert_id: int,
+        ):
+            if "g_idx" in weight_name:
+                return
+            if not layer.quant_config.has_zp and "qzeros" in weight_name:
+                return
+
+            device = get_tp_group().device
+            tp_rank = get_tensor_model_parallel_rank()
+            loaded_weight = loaded_weight.to(device)
+            shard_size = layer.intermediate_size_per_partition
+
+            # convert gptq and awq weight to a standard format
+            if layer.quant_config.linear_quant_method == "awq":
+                assert layer.quant_config.weight_bits == 4
+                if "weight" in weight_name:
+                    loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
+                elif "zeros" in weight_name:
+                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
+                else:
+                    loaded_weight = loaded_weight.T
+            elif layer.quant_config.linear_quant_method == "gptq":
+                assert layer.quant_config.weight_bits in [4, 8]
+                if "weight" in weight_name:
+                    loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
+                elif "zeros" in weight_name:
+                    # add 1 to gptq qzeros to align with awq
+                    loaded_weight = loaded_weight.view(torch.uint8)
+                    if layer.quant_config.weight_bits == 4:
+                        loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
+                    else:
+                        loaded_weight = loaded_weight.T + 1
+                else:
+                    loaded_weight = loaded_weight.T
+
+            # repeat the qzeros/scales to fit new group size
+            if (
+                layer.group_size_div_factor > 1
+                and "qzeros" in weight_name
+                or "scales" in weight_name
+            ):
+                loaded_weight = loaded_weight.repeat_interleave(
+                    layer.group_size_div_factor, 1
+                )
+
+            if "w13_qzeros" in weight_name:
+                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
+                    tp_rank
+                ]
+                if shard_id == "w1":
+                    param.data[expert_id, : shard_size // 2] = tensor
+                else:
+                    param.data[expert_id, shard_size // 2 :] = tensor
+            elif "w2_qzeros" in weight_name:
+                param.data[expert_id] = loaded_weight.view(
+                    loaded_weight.size(0), layer.tp_size, -1
+                )[:, tp_rank]
+            else:
+                weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
+
+        return moe_wna16_weight_loader
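
Editorial note (not part of the diff): convert_awq_tensor above undoes AutoAWQ's int4 nibble packing before the weights reach the Triton MoE kernels. The short standalone sketch below packs eight 4-bit values into one int32 the way AWQ lays them out and then replays the same view/shift/reorder steps to recover the original order. pack_awq_int4 is a hand-rolled stand-in for AutoAWQ's packer, used only to produce something to unpack, and the byte view assumes a little-endian host.

import torch

REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # same table as in the diff

def pack_awq_int4(values: torch.Tensor) -> torch.Tensor:
    # Place logical element j of each group of 8 at nibble position
    # REVERSE_AWQ_PACK_ORDER[j] inside one int32 (AWQ's layout).
    packed = torch.zeros(values.shape[0], values.shape[1] // 8, dtype=torch.int32)
    for j in range(8):
        nibble = values[:, j::8].to(torch.int32)
        packed |= nibble << (4 * REVERSE_AWQ_PACK_ORDER[j])
    return packed

def unpack_awq_int4(qweight: torch.Tensor) -> torch.Tensor:
    # Same steps as convert_awq_tensor, minus the transpose/repack tail.
    size0 = qweight.size(0)
    t = qweight.view(torch.uint8)                 # (a, b) int32 -> (a, 4b) bytes
    shifter = torch.tensor([0, 4], dtype=torch.uint8)
    t = (t[:, :, None] >> shifter) & 0xF          # split each byte into two nibbles
    t = t.view(-1, 8)[:, REVERSE_AWQ_PACK_ORDER]  # undo the AWQ pack order
    return t.view(size0, -1)

vals = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
                     [7, 6, 5, 4, 3, 2, 1, 0]], dtype=torch.int32)
packed = pack_awq_int4(vals)
assert torch.equal(unpack_awq_int4(packed).to(torch.int32), vals)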
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Tuple, Union
+from typing import List, Mapping, Optional, Tuple, Union
 
 import torch
 
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
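
Editorial note (not part of the diff): the apply_router_weight_on_input flag threaded through W8A8Int8MoEMethod.apply above (and the other MoE methods in this release) asks the fused MoE path to scale the expert input by the routing weight instead of scaling the expert output. A tiny illustrative sketch of that distinction, not taken from the package; for purely linear experts the two orderings coincide, which is what the final assert checks.

import torch

def combine_weight_on_output(x, experts, weights):
    # Conventional MoE combine: y = sum_e w_e * f_e(x)
    return sum(w * f(x) for f, w in zip(experts, weights))

def combine_weight_on_input(x, experts, weights):
    # Router weight applied to the expert input instead: y = sum_e f_e(w_e * x)
    return sum(f(w * x) for f, w in zip(experts, weights))

torch.manual_seed(0)
x = torch.randn(4, 8)
experts = [torch.nn.Linear(8, 8, bias=False) for _ in range(2)]
weights = [0.7, 0.3]
out_a = combine_weight_on_output(x, experts, weights)
out_b = combine_weight_on_input(x, experts, weights)
assert torch.allclose(out_a, out_b, atol=1e-5)  # equal here only because the experts are linear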
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         prefix: str = "",
+        use_irope: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
         self.is_cross_attention = is_cross_attention
         self.k_scale = None
         self.v_scale = None
+        self.use_irope = use_irope
 
     def forward(
         self,
@@ -651,18 +651,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available:
-            return self.forward_cuda(positions, query, key, offsets)
-        else:
-            return self.forward_native(positions, query, key, offsets)
-
-    def forward_native(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
         query_rot = query[..., : self.rotary_dim]
@@ -745,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         return new_freqs
 
 
+class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ):
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
+        return inv_freqs
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+
+        # self.max_position_embeddings here is number of image patches
+        # i.e. (image_size // patch_size) ** 2
+        num_patches = self.max_position_embeddings
+        img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
+        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+        img_idx[-1, -1] = -2  # set to ID_CLS_TOKEN
+        num_patches_single_dim = int(math.sqrt(num_patches))
+        frequencies_x = img_idx % num_patches_single_dim
+        frequencies_y = img_idx // num_patches_single_dim
+        freqs_x = (
+            (frequencies_x + 1)[..., None] * inv_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs_y = (
+            (frequencies_y + 1)[..., None] * inv_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+        cache = torch.view_as_complex(
+            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+        )
+        return cache
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+        query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
+        key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
+        broadcast_shape = [
+            d if i == 1 or i == (query_.ndim - 1) else 1
+            for i, d in enumerate(query_.shape)
+        ]
+        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
+        query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
+        key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
+        return query_out.type_as(query), key_out.type_as(key)
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
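
Editorial note (not part of the diff): Llama4VisionRotaryEmbedding.forward applies 2D rotary embeddings by viewing each head dimension as complex pairs and multiplying by precomputed unit phasors (the cos/sin cache stored as a complex tensor). The sketch below is an illustration, not code from the package: rotate_complex mirrors that view_as_complex/view_as_real trick, and rotate_real_pairs is the equivalent explicit cos/sin rotation, so the two agree numerically.

import torch

def rotate_complex(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # View (..., d) as (..., d/2) complex pairs and multiply by e^{i * angle}.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    phasor = torch.polar(torch.ones_like(angles), angles)  # cos + i*sin
    return torch.view_as_real(x_c * phasor).flatten(-2).type_as(x)

def rotate_real_pairs(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # The same rotation written out on interleaved (even, odd) pairs.
    x1, x2 = x.float()[..., 0::2], x.float()[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.flatten(-2).type_as(x)

q = torch.randn(3, 4, 8)        # (tokens, heads, head_dim)
angles = torch.randn(3, 1, 4)   # one angle per complex pair, broadcast over heads
assert torch.allclose(rotate_complex(q, angles), rotate_real_pairs(q, angles), atol=1e-5)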