sglang 0.4.8.post1-py3-none-any.whl → 0.4.9.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
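As a side note, here is a standalone sketch of the exclusion check above; the helper name `is_excluded` is hypothetical and not an sglang API. An entry in `exclude_modules` matches both the bare module path and the same path nested under a `language_model.` wrapper prefix (Python 3.9+ for `str.removeprefix`):

# Standalone sketch mirroring the exclude-module check; `is_excluded` is not part of the package.
def is_excluded(prefix, exclude_modules):
    return any(
        module in prefix
        or (
            prefix.startswith("language_model.")
            and module in prefix.removeprefix("language_model.")
        )
        for module in exclude_modules
    )

# An exclusion written against the plain LM path matches both forms.
assert is_excluded("lm_head", ["lm_head"])
assert is_excluded("language_model.lm_head", ["lm_head"])
assert not is_excluded("model.layers.0.mlp", ["lm_head"])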
 
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
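The `__new__` hook in the hunk above uses what its docstring calls a dynamic class composition pattern: the real base class is imported lazily and injected at runtime with `type()`, so the module avoids importing `fused_moe_triton` at import time. Below is a minimal, standalone sketch of the same idea; `LateBase` and `Method` are hypothetical stand-ins for `FusedMoEMethodBase` and the quant method class, not sglang code.

# Toy sketch of runtime base-class injection via type(); names are hypothetical.
class LateBase:
    # In the real code this base class is imported lazily inside __new__.
    def where(self) -> str:
        return "method inherited from the injected base class"


class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with LateBase injected as a parent, then
        # instantiate and initialize that new class instead of `cls` itself.
        new_cls = type(
            cls.__name__,
            (LateBase,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, tag: str):
        self.tag = tag


m = Method("fp8")
print([c.__name__ for c in type(m).__mro__])  # ['Method', 'LateBase', 'object']
print(m.where(), m.tag)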
 
sglang/srt/layers/quantization/moe_wna16.py

@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if capability_tuple is None
+            if all(capability is None for capability in capability_tuple)
             else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
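The new condition reflects that `get_device_capability()` yields a `(major, minor)` tuple whose entries can both be `None` when no capable device is present, rather than returning `None` itself. A minimal sketch of the resulting branch, using a standalone helper with a hypothetical name:

# Sketch of the guarded capability computation; assumes inputs like (9, 0) or (None, None).
def to_device_capability(capability_tuple):
    return (
        -1
        if all(capability is None for capability in capability_tuple)
        else capability_tuple[0] * 10 + capability_tuple[1]
    )

print(to_device_capability((9, 0)))        # 90, e.g. a Hopper-class device
print(to_device_capability((None, None)))  # -1 when no capability is reported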
sglang/srt/layers/quantization/quant_utils.py (new file)

@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
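A quick roundtrip of the new column-packing helpers, assuming an installed sglang 0.4.9 environment where `sgl_kernel` is importable (the module imports `ScalarType` at import time); `quantize_weights` is omitted here since it additionally needs a `ScalarType` instance:

# Roundtrip sketch for pack_cols / unpack_cols from the new module.
import torch
from sglang.srt.layers.quantization.quant_utils import pack_cols, unpack_cols

num_bits, size_k, size_n = 4, 8, 16
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)
restored = unpack_cols(packed, num_bits, size_k, size_n)

assert packed.shape == (size_k, size_n // (32 // num_bits))  # columns packed into int32 words
assert torch.equal(restored, q_w)                            # lossless for values < 2**num_bits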