sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sglang/srt/configs/model_config.py +15 -6
  2. sglang/srt/layers/attention/flashinfer_backend.py +17 -3
  3. sglang/srt/layers/linear.py +36 -98
  4. sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
  5. sglang/srt/layers/moe/topk.py +4 -2
  6. sglang/srt/layers/parameter.py +24 -16
  7. sglang/srt/layers/quantization/__init__.py +2 -0
  8. sglang/srt/layers/quantization/fp8.py +106 -52
  9. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  10. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  11. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  12. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  13. sglang/srt/layers/radix_attention.py +2 -0
  14. sglang/srt/layers/vocab_parallel_embedding.py +15 -2
  15. sglang/srt/managers/configure_logging.py +43 -0
  16. sglang/srt/managers/detokenizer_manager.py +0 -2
  17. sglang/srt/managers/io_struct.py +29 -13
  18. sglang/srt/managers/scheduler.py +48 -9
  19. sglang/srt/managers/tokenizer_manager.py +109 -49
  20. sglang/srt/mem_cache/memory_pool.py +107 -52
  21. sglang/srt/metrics/collector.py +10 -5
  22. sglang/srt/model_executor/model_runner.py +43 -6
  23. sglang/srt/models/llama.py +37 -2
  24. sglang/srt/models/qwen2.py +11 -0
  25. sglang/srt/models/qwen2_eagle.py +131 -0
  26. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  27. sglang/srt/sampling/sampling_batch_info.py +14 -5
  28. sglang/srt/sampling/sampling_params.py +1 -1
  29. sglang/srt/server.py +114 -61
  30. sglang/srt/server_args.py +27 -18
  31. sglang/srt/speculative/eagle_worker.py +1 -0
  32. sglang/srt/torch_memory_saver_adapter.py +59 -0
  33. sglang/srt/utils.py +29 -0
  34. sglang/version.py +1 -1
  35. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
  36. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
  37. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  38. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
  39. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/configs/model_config.py
+++ b/sglang/srt/configs/model_config.py
@@ -223,7 +223,11 @@ class ModelConfig:
             "compressed_tensors",
             "compressed-tensors",
             "experts_int8",
+            "w8a8_int8",
         ]
+        compatible_quantization_methods = {
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+        }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -247,12 +251,17 @@ class ModelConfig:
             if self.quantization is None:
                 self.quantization = quant_method
             elif self.quantization != quant_method:
-                raise ValueError(
-                    "Quantization method specified in the model config "
-                    f"({quant_method}) does not match the quantization "
-                    f"method specified in the `quantization` argument "
-                    f"({self.quantization})."
-                )
+                if (
+                    self.quantization not in compatible_quantization_methods
+                    or quant_method
+                    not in compatible_quantization_methods[self.quantization]
+                ):
+                    raise ValueError(
+                        "Quantization method specified in the model config "
+                        f"({quant_method}) does not match the quantization "
+                        f"method specified in the `quantization` argument "
+                        f"({self.quantization})."
+                    )

         if self.quantization is not None:
             if self.quantization not in supported_quantization:
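
Note: the model_config.py hunks above relax the previously strict equality check between the checkpoint's own `quant_method` and the user-supplied `quantization` argument: a mismatch is now tolerated when the pair appears in `compatible_quantization_methods` (currently only `w8a8_int8` accepting `compressed-tensors` checkpoints). A minimal sketch of the resulting logic, using a hypothetical standalone helper rather than sglang's `ModelConfig`:

```python
# Minimal sketch (hypothetical helper, not part of sglang) of the relaxed
# compatibility check introduced in model_config.py above.
compatible_quantization_methods = {
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
}

def check_quantization(user_choice: str, checkpoint_method: str) -> None:
    """Raise only when the two methods differ AND are not declared compatible."""
    if user_choice == checkpoint_method:
        return
    if checkpoint_method not in compatible_quantization_methods.get(user_choice, []):
        raise ValueError(
            f"Quantization method in the model config ({checkpoint_method}) does not "
            f"match the `quantization` argument ({user_choice})."
        )

check_quantization("w8a8_int8", "compressed-tensors")  # accepted after this change
# check_quantization("fp8", "awq")                     # still raises ValueError
```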
--- a/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/sglang/srt/layers/attention/flashinfer_backend.py
@@ -84,6 +84,10 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

+        # Qwen2 models require higher flashinfer workspace size
+        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+
         # Allocate buffers
         self.workspace_buffer = torch.empty(
             global_config.flashinfer_workspace_size,
@@ -353,7 +357,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +368,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +395,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -412,13 +422,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
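
Note: the recurring flashinfer_backend.py change threads `layer.k_scale` / `layer.v_scale` both into the KV-cache write (`set_kv_buffer`) and into the flashinfer prefill/decode kernels, so a quantized KV cache can be scaled consistently on store and on read. A rough sketch of the store side, assuming a toy buffer and float scales (not sglang's actual `token_to_kv_pool`, which lives in sglang/srt/mem_cache/memory_pool.py):

```python
import torch

# Toy KV store illustrating the set_kv_buffer(..., k_scale, v_scale) pattern above.
class ToyKVBuffer:
    def __init__(self, num_slots: int, num_heads: int, head_dim: int, store_dtype=torch.float16):
        self.k = torch.zeros(num_slots, num_heads, head_dim, dtype=store_dtype)
        self.v = torch.zeros(num_slots, num_heads, head_dim, dtype=store_dtype)

    def set_kv(self, loc, k, v, k_scale=None, v_scale=None):
        # Divide by the per-layer scale before narrowing the dtype; the attention
        # kernel multiplies the scale back in when it reads the cache.
        if k_scale is not None:
            k = k / k_scale
        if v_scale is not None:
            v = v / v_scale
        self.k[loc] = k.to(self.k.dtype)
        self.v[loc] = v.to(self.v.dtype)

buf = ToyKVBuffer(num_slots=16, num_heads=2, head_dim=4)
buf.set_kv(torch.tensor([0, 1]), torch.randn(2, 2, 4), torch.randn(2, 2, 4), k_scale=1.0, v_scale=1.0)
```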
--- a/sglang/srt/layers/linear.py
+++ b/sglang/srt/layers/linear.py
@@ -1,4 +1,4 @@
-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

 import logging
 from abc import abstractmethod
@@ -16,7 +16,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )

-# workaround
+# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
 from vllm.model_executor.layers.linear import LinearBase

 from sglang.srt.layers.parameter import (
@@ -25,7 +25,6 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
-    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -43,9 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
-    "GPTQLinearMethod",
     "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
 ]


@@ -95,62 +98,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight


-def load_column_qkv_weight(
-    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
-):
-    if (
-        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
-        and self.output_dim == self.packed_dim
-    ):
-        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
-            shard_offset=shard_offset, shard_size=shard_size
-        )
-
-    param_data = self.data
-    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
-    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-    loaded_weight = loaded_weight.narrow(
-        self.output_dim, shard_id * shard_size, shard_size
-    )
-
-    assert param_data.shape == loaded_weight.shape
-    param_data.copy_(loaded_weight)
-
-
-def load_column_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, _ColumnvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
-            )
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
-def load_row_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, RowvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""

@@ -426,9 +373,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -437,7 +382,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -565,9 +510,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )

-                assert (
-                    param_data.shape == loaded_weight.shape
-                ), f"{param_data.shape=}, {loaded_weight.shape=}"
+                assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
             current_shard_offset = 0
@@ -643,9 +586,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         "the same for all partitions."
                     )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def _load_fused_module_from_checkpoint(
@@ -697,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -882,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -896,24 +839,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n

-        if isinstance(param, _ColumnvLLMParameter):
-            load_column_qkv_weight(
-                param,
-                loaded_weight,
-                num_heads=self.num_kv_head_replicas,
-                shard_id=loaded_shard_id,
-                shard_offset=shard_offset,
-                shard_size=shard_size,
-                tp_rank=self.tp_rank,
-            )
-        else:
-            param.load_qkv_weight(
-                loaded_weight=loaded_weight,
-                num_heads=self.num_kv_head_replicas,
-                shard_id=loaded_shard_id,
-                shard_offset=shard_offset,
-                shard_size=shard_size,
-            )
+        param.load_qkv_weight(
+            loaded_weight=loaded_weight,
+            num_heads=self.num_kv_head_replicas,
+            shard_id=loaded_shard_id,
+            shard_offset=shard_offset,
+            shard_size=shard_size,
+            tp_rank=self.tp_rank,
+        )

     def weight_loader(
         self,
@@ -962,9 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )

-                assert (
-                    param_data.shape == loaded_weight.shape
-                ), f"{param_data.shape=}, {loaded_weight.shape=}"
+                assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
             shard_offsets = [
@@ -1105,9 +1036,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                         "for all partitions."
                     )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)


@@ -1234,9 +1163,7 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1247,7 +1174,18 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        if isinstance(param, BasevLLMParameter):
+            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+            # It supports additional parameters like tp_rank and use_presharded_weights.
+            param.load_row_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            # `params` is defined in `vllm/model_executor/parameter.py`,
+            # It does not support additional parameters.
+            param.load_row_parallel_weight(loaded_weight)

     def forward(self, input_):
         if self.input_is_parallel:
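
Note: with the monkey-patched module-level loaders removed, `RowParallelLinear.weight_loader_v2` now dispatches on the parameter type: sglang's own `BasevLLMParameter` subclasses accept `tp_rank` / `use_presharded_weights` explicitly, while parameters still coming from vLLM keep the old one-argument signature. A toy sketch of that dispatch (illustrative classes, not the real ones):

```python
import torch

# Toy versions of the two parameter flavours, to illustrate the isinstance
# dispatch added to RowParallelLinear.weight_loader_v2 above.
class SglangStyleParam(torch.nn.Parameter):
    def load_row_parallel_weight(self, w, tp_rank, use_presharded_weights=False):
        if not use_presharded_weights:
            shard = self.shape[1]                       # row-parallel: shard the input dim
            w = w.narrow(1, tp_rank * shard, shard)
        self.data.copy_(w)

class VllmStyleParam(torch.nn.Parameter):
    def load_row_parallel_weight(self, w):
        self.data.copy_(w)                              # caller must pre-shard `w`

def load(param, full_weight, tp_rank):
    if isinstance(param, SglangStyleParam):
        param.load_row_parallel_weight(full_weight, tp_rank=tp_rank)
    else:
        shard = param.shape[1]
        param.load_row_parallel_weight(full_weight.narrow(1, tp_rank * shard, shard))

full = torch.arange(8.0).reshape(2, 4)      # full weight: [out=2, in=4]
p = SglangStyleParam(torch.empty(2, 2))     # rank-local shard: in=2
load(p, full, tp_rank=1)
print(p.data)                               # columns 2..3 of `full`
```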
--- a/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -27,6 +27,8 @@ else:

 import logging

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)


@@ -97,6 +99,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,14 +164,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-        return fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )

     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError("The CPU backend currently does not support MoE.")
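
Note: on ROCm, `UnquantizedFusedMoEMethod` now pre-permutes the expert weights after loading (`process_weights_after_loading`) and dispatches to the `ater` Composable Kernel `fused_experts_ck` instead of the Triton `fused_experts`, both gated on the `CK_MOE` environment variable. The gating itself is a plain env-flag check resolved at call time; a hedged sketch with a hypothetical stand-in for `get_bool_env_var`:

```python
import os

# Hypothetical stand-in for sglang.srt.utils.get_bool_env_var, shown only to
# illustrate the CK_MOE gating used in the hunks above.
def env_flag(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).lower() in ("1", "true", "yes")

def pick_moe_kernel(on_rocm: bool) -> str:
    if on_rocm and env_flag("CK_MOE"):
        return "ater.fused_moe.fused_experts_ck"   # CK path, expects permuted weights
    return "fused_moe_triton.fused_experts"        # default Triton path

print(pick_moe_kernel(on_rocm=False))  # -> default Triton path
```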
--- a/sglang/srt/layers/moe/topk.py
+++ b/sglang/srt/layers/moe/topk.py
@@ -24,7 +24,9 @@ def fused_topk_native(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert (
+        hidden_states.shape[0] == gating_output.shape[0]
+    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
         M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
         )
-    elif torch_native:
+    elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
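
Note: `select_experts` now falls back to `fused_topk_native` only when no `custom_routing_function` is supplied, and the assert reports the mismatching shapes. For context, top-k routing of this kind is, schematically, a softmax over the router logits followed by `topk` and an optional renormalization (details may differ from `fused_topk_native` itself):

```python
import torch

# Schematic top-k expert routing: softmax over router logits, pick the top-k
# experts per token, optionally renormalize the selected weights to sum to 1.
def topk_routing(router_logits: torch.Tensor, topk: int, renormalize: bool):
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

w, ids = topk_routing(torch.randn(4, 8), topk=2, renormalize=True)  # 4 tokens, 8 experts
```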
--- a/sglang/srt/layers/parameter.py
+++ b/sglang/srt/layers/parameter.py
@@ -1,7 +1,4 @@
-"""
-Adapted from vLLM (0.6.4.post1).
-https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
-"""
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""

 import logging
 from fractions import Fraction
@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     def output_dim(self):
         return self._output_dim

-    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.output_dim]
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, tp_rank * shard_size, shard_size
-        )
+    def load_column_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)

@@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):

         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )

         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.narrow(
@@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter):
     def input_dim(self):
         return self._input_dim

-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        use_presharded_weights = kwargs.get("use_presharded_weights")
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.input_dim]
+    def load_row_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
         if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
             loaded_weight = loaded_weight.narrow(
                 self.input_dim, tp_rank * shard_size, shard_size
             )
@@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
     def load_row_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)

     def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)

     def load_column_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)

     def _load_into_shard_id(
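
Note: the parameter.py loaders now take `tp_rank` from the caller instead of querying `get_tensor_model_parallel_rank()`, and they skip the `narrow` entirely when the checkpoint is already presharded. The sharding itself is just a slice along the parallel dimension; a standalone numeric illustration:

```python
import torch

# Full (unsharded) weight: 8 output rows, 6 input columns.
full_weight = torch.arange(48, dtype=torch.float32).reshape(8, 6)

tp_size, tp_rank = 2, 1          # pretend we are rank 1 of 2
output_dim, input_dim = 0, 1

# Column-parallel layers shard the output dimension ...
col_shard_size = full_weight.shape[output_dim] // tp_size            # 4 rows
col_shard = full_weight.narrow(output_dim, tp_rank * col_shard_size, col_shard_size)

# ... row-parallel layers shard the input dimension.
row_shard_size = full_weight.shape[input_dim] // tp_size             # 3 columns
row_shard = full_weight.narrow(input_dim, tp_rank * row_shard_size, row_shard_size)

print(col_shard.shape)  # torch.Size([4, 6])
print(row_shard.shape)  # torch.Size([8, 3])
```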
--- a/sglang/srt/layers/quantization/__init__.py
+++ b/sglang/srt/layers/quantization/__init__.py
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -42,6 +43,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "w8a8_int8": W8A8Int8Config,
 }

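
Note: registering `W8A8Int8Config` under the `"w8a8_int8"` key makes the new method resolvable through the same name-to-config lookup as the other backends; schematically (a simplified, standalone version of the registry, not sglang's actual module):

```python
from typing import Dict, Type

# Simplified stand-ins for the real config classes.
class QuantizationConfig: ...
class W8A8Int8Config(QuantizationConfig): ...

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w8a8_int8": W8A8Int8Config,
}

def get_quantization_config(name: str) -> Type[QuantizationConfig]:
    if name not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {name}")
    return QUANTIZATION_METHODS[name]

print(get_quantization_config("w8a8_int8"))  # <class '__main__.W8A8Int8Config'>
```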