sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/quant_utils.py ADDED
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
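
For orientation (this example is not part of the diff): the new helpers pack sub-byte integer weights column-wise into int32 words and produce reference-quantized weights for testing. A minimal sketch of how they compose, assuming sgl_kernel also ships the vLLM-style scalar type table (scalar_types.uint4b8); treat the import path for that table as an assumption:

```python
import torch
from sgl_kernel.scalar_type import scalar_types  # assumed vLLM-style scalar type table

from sglang.srt.layers.quantization.quant_utils import (
    get_pack_factor,
    pack_cols,
    quantize_weights,
    unpack_cols,
)

size_k, size_n, num_bits = 128, 256, 4

# Group-quantize a float weight to 4-bit, one scale per 64-row group.
w = torch.randn(size_k, size_n, dtype=torch.float16)
w_ref, w_q, w_s, _ = quantize_weights(w, scalar_types.uint4b8, group_size=64)

# pack_cols stores 32 // num_bits = 8 quantized columns per int32; unpack_cols inverts it.
packed = pack_cols(w_q, num_bits, size_k, size_n)
assert packed.shape == (size_k, size_n // get_pack_factor(num_bits))
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), w_q)

# w_ref is the dequantized reference (w is approximately w_ref); w_s holds per-group scales.
print(w_s.shape, (w - w_ref).abs().max())
```
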
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["weight"])
+            return
+
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
@@ -112,6 +128,16 @@
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                x,
+                layer.weight,
+                layer.weight_scale,
+                bias,
+                x.dtype,
+                True,  # is_vnni
+            )
+
         x_q, x_scale = per_token_quant_int8(x)
 
         return int8_scaled_mm(
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
@@ -252,6 +285,24 @@
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                True,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                layer.w13_weight_scale,  # w1_scale
+                layer.w2_weight_scale,  # w2_scale
+                None,  # block_size
+                layer.w13_input_scale,  # a1_scale
+                layer.w2_input_scale,  # a2_scale
+                True,  # is_vnni
+            )
+
         return fused_experts(
             x,
             layer.w13_weight,
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -8,16 +8,29 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope as aiter_get_rope
+
+if is_npu():
+    import torch_npu
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -152,6 +165,36 @@
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -617,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
-        device: Optional[str] = "cuda",
+        device: Optional[str] = "cuda" if not _is_npu else "npu",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -636,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip:
+        if _is_hip or _is_npu:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -847,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
@@ -1191,15 +1271,26 @@ def get_rope(
         )
     elif scaling_type == "dynamic":
        scaling_factor = rope_scaling["factor"]
-        rotary_emb = DynamicNTKScalingRotaryEmbedding(
-            head_size,
-            rotary_dim,
-            max_position,
-            base,
-            is_neox_style,
-            scaling_factor,
-            dtype,
-        )
+        if "alpha" in rope_scaling:
+            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                rope_scaling["alpha"],
+                dtype,
+            )
+        else:
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                scaling_factor,
+                dtype,
+            )
     elif scaling_type == "yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling["original_max_position_embeddings"]
@@ -1388,7 +1479,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
-        return get_rope(
+        wrapper = aiter_get_rope if _use_aiter else get_rope
+        return wrapper(
             head_size,
             rotary_dim,
             max_position,
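
For context (not part of the diff): a "dynamic" rope_scaling config that carries an "alpha" key is now routed to the new class, which rescales the RoPE base as base * alpha ** (rotary_dim / (rotary_dim - 2)) instead of stretching positions by a fixed factor. A small sketch using the class directly; the head and sequence sizes below are made up, and forward_native is used to stay on the device-agnostic reference path:

```python
import torch

from sglang.srt.layers.rotary_embedding import DynamicNTKAlphaRotaryEmbedding

rope = DynamicNTKAlphaRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,
    scaling_alpha=1000.0,
    dtype=torch.float32,
)

positions = torch.arange(8)
query = torch.randn(8, 32 * 128)  # [num_tokens, num_q_heads * head_size]
key = torch.randn(8, 8 * 128)     # [num_tokens, num_kv_heads * head_size]

# Apply the rotary embedding via the pure-PyTorch reference implementation.
query, key = rope.forward_native(positions, query, key)
```
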
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,10 +21,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -242,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.tp_size = 1
 
         self.num_embeddings = num_embeddings
-        self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
+
+        # Support the case where the vocab size is not divisible by the TP size.
+        if (
+            _is_cpu
+            and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+        ):
+            padding_size *= self.tp_size
+        self.padding_size = padding_size
+
         num_added_embeddings = num_embeddings - self.org_vocab_size
         self.use_presharded_weights = use_presharded_weights
         if use_presharded_weights:
@@ -549,6 +561,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
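
For context (not part of the diff): the padding tweak above is plain integer rounding. pad_vocab_size rounds the vocabulary up to a multiple of padding_size, and on CPU the padding unit is widened by tp_size when the padded size would not split evenly across tensor-parallel ranks. A worked example with illustrative numbers; the helper below only mirrors the rounding behavior to make the arithmetic concrete:

```python
def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    # Round up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size, padding_size, tp_size = 151936, 64, 6

padded = pad_vocab_size(org_vocab_size, padding_size)
print(padded, padded % tp_size)   # 151936 4 -> cannot be split evenly across 6 ranks

padded = pad_vocab_size(org_vocab_size, padding_size * tp_size)
print(padded, padded % tp_size)   # 152064 0 -> 25344 rows per rank
```
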
sglang/srt/lora/lora.py CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
 
@@ -88,10 +88,9 @@
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
 
sglang/srt/lora/lora_manager.py CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
 
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )
 
-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
             If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
 
+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
 
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
             self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
 
-        self.update_state_from_configs()
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
 
-    def unload_lora_adapters(self, lora_names: Set[str]):
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
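
For context (not part of the diff): the manager now returns a LoRAUpdateResult (success flag, aggregated error_message, and the currently loaded adapter map) instead of only logging warnings. A hypothetical caller, assuming an already-constructed LoRAManager; the adapter names and paths below are made up:

```python
from sglang.srt.lora.lora_manager import LoRAManager


def refresh_adapters(manager: LoRAManager) -> None:
    # Single adapter: update_state defaults to True, so internal state is refreshed.
    result = manager.load_lora_adapter("sql-lora", "/models/loras/sql-lora")
    if not result.success:
        print("load failed:", result.error_message)
    print("loaded adapters:", sorted(result.loaded_adapters))

    # Batch load: each adapter is loaded with update_state=False and the state is
    # refreshed once; per-adapter errors are joined into a single message.
    batch = manager.load_lora_adapters(
        {"chat-lora": "/models/loras/chat", "code-lora": "/models/loras/code"}
    )
    if not batch.success:
        print("some adapters failed:", batch.error_message)

    # Unloading an unknown adapter no longer just warns; it reports success=False.
    print(manager.unload_lora_adapter("not-loaded").success)
```
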
sglang/srt/managers/configure_logging.py CHANGED
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=2)
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )