sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
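
ModelOptFp8MoEMethod, returned for FusedMoE layers above and defined in the next hunk, injects FusedMoEMethodBase as a parent class at runtime inside __new__ so that the base class only needs to be imported lazily, sidestepping a circular import. A minimal standalone sketch of that composition pattern, using hypothetical classes rather than the package's own:

class Base:  # stands in for FusedMoEMethodBase, which would be imported lazily
    def run(self):
        return f"base sees value={self.value}"


class Method:  # stands in for ModelOptFp8MoEMethod
    def __new__(cls, *args, **kwargs):
        # Build a new type at runtime that inherits from Base but reuses this
        # class's methods, then instantiate and initialize it directly.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = object.__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, value):
        self.value = value


m = Method(42)
print(isinstance(m, Base), m.run())  # True base sees value=42

The constructed instance behaves as if Method had been declared with Base as a parent, even though Base never appears at module scope.
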
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
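
The numerically interesting part of the new ModelOptFp8MoEMethod is the scale handling in process_weights_after_loading: the checkpoint ships one scale for w1 and one for w3 per expert, but the fused FP8 MoE kernel wants a single w13 scale per expert, so each shard is dequantized with its original scale and requantized against the per-expert maximum. A minimal numeric sketch of that idea, using simple per-tensor quant/dequant stand-ins rather than the package's per_tensor_dequantize and scaled_fp8_quant kernels:

import torch

# Stand-ins following the usual per-tensor convention q ~ x / scale, x ~ q * scale.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequant(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

w1, w3 = torch.randn(4, 8), 3.0 * torch.randn(4, 8)      # w3 has a wider range
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
q1, q3 = quant(w1, s1), quant(w3, s3)                     # per-shard checkpoint data

s13 = torch.max(s1, s3)                                   # one combined scale per expert
q1 = quant(dequant(q1, s1), s13)                          # dequant with old scale, requant with max
q3 = quant(dequant(q3, s3), s13)

print((dequant(q1, s13) - w1).abs().max(), (dequant(q3, s13) - w3).abs().max())

Taking the max of the two scales keeps both shards representable without clipping; the shard that was originally quantized with the smaller scale simply loses a little precision when re-encoded.
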
 
sglang/srt/layers/quantization/moe_wna16.py

@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
-        if can_convert and user_quant == "moe_wna16":
+        if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
             return cls.get_name()
         return None
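
The rewrite above is behavior-preserving apart from evaluation order: the compatibility probe now runs only when the user explicitly requested moe_wna16. A trivial illustration of that short-circuit, with a dummy stand-in for is_moe_wna16_compatible:

calls = 0

def is_compatible(cfg) -> bool:   # dummy stand-in for is_moe_wna16_compatible
    global calls
    calls += 1
    return True

user_quant = "awq"  # anything other than "moe_wna16"
result = "moe_wna16" if (user_quant == "moe_wna16" and is_compatible({})) else None
print(result, calls)  # None 0 -- the compatibility probe never ran
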
 
sglang/srt/layers/quantization/w4afp8.py (new file)

@@ -0,0 +1,264 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.utils import set_weight_attrs
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class W4AFp8Config(QuantizationConfig):
+    """Config class for MIXED_PRECISION W4AFp8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = True,
+        is_checkpoint_w4afp8_serialized: bool = True,
+        linear_activation_scheme: str = "dynamic",
+        moe_activation_scheme: str = "static",
+        ignored_layers: Optional[List[str]] = None,
+        weight_block_size: Optional[List[int]] = None,
+        group_size: int = 128,
+    ) -> None:
+        super().__init__()
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
+        if is_checkpoint_w4afp8_serialized:
+            logger.warning("Detected w4afp8 checkpoint. Please note that")
+        if moe_activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
+        self.linear_activation_scheme = linear_activation_scheme
+        self.moe_activation_scheme = moe_activation_scheme
+        self.ignored_layers = ignored_layers or []
+        self.weight_block_size = [128, 128]
+        self.group_size = group_size
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "w4afp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
+        linear_activation_scheme = "dynamic"
+        moe_activation_scheme = "static"
+        weight_block_size = [128, 128]
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
+            linear_activation_scheme=linear_activation_scheme,
+            moe_activation_scheme=moe_activation_scheme,
+            weight_block_size=weight_block_size,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W4AFp8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W4AFp8MoEMethod:
+
+    def __init__(self, quant_config: W4AFp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts_per_partition: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        assert "weight_loader" in extra_weight_attrs
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                intermediate_size * 2,
+                hidden_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                2 * intermediate_size,
+                hidden_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # Input scales
+        w13_input_scale = torch.nn.Parameter(
+            torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+        set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+        w2_input_scale = torch.nn.Parameter(
+            torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+        set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        # Pre-populate the strides
+        device = layer.w13_weight.device
+
+        self.a_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            2 * intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.a_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.b_strides1 = self.a_strides1
+        self.s_strides13 = self.c_strides1
+        self.b_strides2 = self.a_strides2
+        self.s_strides2 = self.c_strides2
+
+        self.expert_offsets = torch.empty(
+            (num_experts_per_partition + 1), dtype=torch.int32, device=device
+        )
+        self.problem_sizes1 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+        self.problem_sizes2 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+
+        return
+
+    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
+        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+        s_shape = scales.shape
+        # Reshape to separate groups of 4
+        scales_interleaved = scales.reshape(
+            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
+        )
+        # Permute dimensions to interleave
+        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+        # Reshape back to original dimensions but with interleaved values
+        scales_interleaved = scales_interleaved.reshape(
+            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
+        )
+        return scales_interleaved.contiguous()
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        dtype = torch.bfloat16
+        device = layer.w2_weight.device
+
+        # Interleave w13_weight_scale (gate_up_proj)
+        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
+        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
+
+        # Interleave w2_weight_scale (down_proj)
+        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
+        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
+
+        # Process input scales
+        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+        new_w13_input_scale = torch.tensor(
+            [w13_input_scale_max],
+            dtype=dtype,
+            device=device,
+        )
+        layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
+
+        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+        new_w2_input_scale = torch.tensor(
+            [w2_input_scale_max], dtype=dtype, device=device
+        )
+        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
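
To make the layout change performed by _interleave_scales concrete, here is a small standalone sketch run on a toy tensor; the free function and toy shapes below are illustrative only, the packaged method is the one shown above:

import torch

def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    # Same reshape -> permute -> reshape sequence as W4AFp8MoEMethod._interleave_scales.
    e, n, g = scales.shape          # (experts, output rows, k-groups per row)
    return (
        scales.reshape(e, n, g // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(e, g // 4, n * 4)
        .contiguous()
    )

# 1 expert, 2 output rows, 8 scale groups per row.
s = torch.arange(16).reshape(1, 2, 8)
print(interleave_scales(s))
# tensor([[[ 0,  1,  2,  3,  8,  9, 10, 11],
#          [ 4,  5,  6,  7, 12, 13, 14, 15]]])

After the transform, each output row holds the group-of-4 scale chunk from every original row for one k-group index, which is the TRT-LLM-style interleaving the docstring refers to.
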