sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/parameter.py CHANGED
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 
 logger = logging.getLogger(__name__)
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
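Note: the parameter.py hunks above route CPU weight loading through narrow_padded_param_and_loaded_weight (defined in sglang/srt/model_loader/weight_utils.py) so that a parameter padded for the CPU backend can still be filled from an unpadded checkpoint shard. A minimal sketch of the idea, with a hypothetical helper name and toy shapes (not the actual implementation):

import torch

def narrow_padded(param: torch.Tensor, loaded: torch.Tensor,
                  shard_start: int, dim: int, shard_size: int):
    # Only copy as many elements along `dim` as the checkpoint shard provides;
    # the remaining (padded) tail of `param` is left untouched.
    valid = min(shard_size, loaded.shape[dim] - shard_start)
    return param.narrow(dim, 0, valid), loaded.narrow(dim, shard_start, valid)

param = torch.zeros(8, 10)   # column-parallel shard, padded from 9 to 10 columns
ckpt = torch.randn(8, 18)    # full, unpadded weight (2 ranks x 9 real columns)
p, w = narrow_padded(param, ckpt, shard_start=9, dim=1, shard_size=10)
p.copy_(w)                   # fills the 9 real columns of rank 1's shard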
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py CHANGED
@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,
sglang/srt/layers/quantization/fp8.py CHANGED
@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -330,6 +332,12 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
 
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
                 layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +434,17 @@
             )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +765,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +997,24 @@
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
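Note: the fp8.py hunks add a CPU path: process_weights_after_loading repacks weights via _amx_process_weight_after_loading when the host has Intel AMX, and the apply paths short-circuit into sgl_kernel CPU ops whenever use_intel_amx_backend(layer) is true. A simplified sketch of that dispatch shape, with placeholder kernel calls (the real ops live in torch.ops.sgl_kernel and only exist in CPU builds of sgl-kernel):

import torch

class CpuAwareLinearMethodSketch:
    def process_weights_after_loading(self, layer):
        if getattr(layer, "use_intel_amx_backend", False):
            # stand-in for _amx_process_weight_after_loading(layer, ["weight"])
            layer.weight = torch.nn.Parameter(layer.weight.data.contiguous(), requires_grad=False)
            return
        # ... existing CUDA/ROCm weight processing continues here ...

    def apply(self, layer, x, bias=None):
        if getattr(layer, "use_intel_amx_backend", False):
            out = x @ layer.weight.t()   # stand-in for the fused AMX kernel
            return out + bias if bias is not None else out
        # ... otherwise fall through to the existing GPU kernels ...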
sglang/srt/layers/quantization/fp8_kernel.py CHANGED
@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
sglang/srt/layers/quantization/fp8_utils.py CHANGED
@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
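Note: both fp8_kernel.py and fp8_utils.py now import align from sglang.srt.utils instead of sglang.math_utils (which shrank accordingly, see sglang/math_utils.py +0 -8 above). Assuming the helper keeps the usual round-up-to-multiple semantics, it behaves like:

def align(x: int, y: int) -> int:
    # round x up to the nearest multiple of y (assumed semantics)
    return (x + y - 1) // y * y

assert align(125, 64) == 128 and align(128, 64) == 128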
sglang/srt/layers/quantization/gptq.py CHANGED
@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            num_bits=self.quant_config.quant_type.size_bits,
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
sglang/srt/layers/quantization/moe_wna16.py CHANGED
@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if capability_tuple is None
+            if all(capability is None for capability in capability_tuple)
            else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
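Note: get_device_capability() here can return a (None, None) tuple rather than None on hosts without a visible GPU, so the guard now checks every element. For illustration, the encoded value the surrounding code computes:

capability_tuple = (9, 0)  # e.g. an SM 9.0 GPU; (None, None) when no device is present
device_capability = (
    -1
    if all(capability is None for capability in capability_tuple)
    else capability_tuple[0] * 10 + capability_tuple[1]
)
assert device_capability == 90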
sglang/srt/layers/quantization/quant_utils.py ADDED
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
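Note: the new quant_utils.py (adapted from vLLM) provides column-wise bit packing plus a reference integer quantizer. A small round-trip example of pack_cols/unpack_cols (illustrative; importing the module requires sgl_kernel, since it pulls in ScalarType at import time):

import torch
from sglang.srt.layers.quantization.quant_utils import pack_cols, unpack_cols

size_k, size_n, num_bits = 16, 32, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
packed = pack_cols(q_w, num_bits, size_k, size_n)       # 8 4-bit values per int32 -> shape (16, 4)
restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(restored, q_w)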
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["weight"])
+            return
+
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
@@ -112,6 +128,16 @@
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                x,
+                layer.weight,
+                layer.weight_scale,
+                bias,
+                x.dtype,
+                True,  # is_vnni
+            )
+
         x_q, x_scale = per_token_quant_int8(x)
 
         return int8_scaled_mm(
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
@@ -252,6 +285,24 @@
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                True,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                layer.w13_weight_scale,  # w1_scale
+                layer.w2_weight_scale,  # w2_scale
+                None,  # block_size
+                layer.w13_input_scale,  # a1_scale
+                layer.w2_input_scale,  # a2_scale
+                True,  # is_vnni
+            )
+
         return fused_experts(
             x,
             layer.w13_weight,
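Note: when the AMX backend is not taken, the linear path still quantizes activations per token before int8_scaled_mm. A rough reference of what per_token_quant_int8 is assumed to compute (symmetric int8 with one scale per row; illustrative only):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 8)
x_q, x_scale = per_token_quant_int8_ref(x)
assert torch.allclose(x_q.float() * x_scale, x, atol=x_scale.max().item())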
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
-        device: Optional[str] = "cuda",
+        device: Optional[str] = "cuda" if not _is_npu else "npu",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip:
+        if _is_hip or _is_npu:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -250,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.tp_size = 1
 
         self.num_embeddings = num_embeddings
-        self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
+
+        # Support the case where the vocab size is not divisible by the TP size.
+        if (
+            _is_cpu
+            and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+        ):
+            padding_size *= self.tp_size
+        self.padding_size = padding_size
+
         num_added_embeddings = num_embeddings - self.org_vocab_size
         self.use_presharded_weights = use_presharded_weights
         if use_presharded_weights:
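Note: on CPU the padding size is scaled by the TP size whenever the padded vocabulary would not split evenly across ranks. A worked example with illustrative numbers (pad_vocab_size shown inline with its usual round-up behavior):

def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    return (vocab_size + pad_to - 1) // pad_to * pad_to

org_vocab_size, padding_size, tp_size = 151_936, 64, 6
if pad_vocab_size(org_vocab_size, padding_size) % tp_size != 0:
    padding_size *= tp_size   # 64 -> 384, so every rank receives an equal shard
assert pad_vocab_size(org_vocab_size, padding_size) % tp_size == 0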
sglang/srt/lora/lora.py CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
 
@@ -88,10 +88,9 @@
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
 