sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -84,6 +84,10 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

+        # Qwen2 models require higher flashinfer workspace size
+        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+
         # Allocate buffers
         self.workspace_buffer = torch.empty(
             global_config.flashinfer_workspace_size,
@@ -347,11 +351,15 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

+        logits_soft_cap = layer.logit_cap
+
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -359,7 +367,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -368,7 +378,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )

             if self.forward_metadata.extend_no_prefix:
@@ -385,7 +395,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -410,13 +422,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
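Note: the `k_scale` / `v_scale` arguments added above thread per-layer KV-cache dequantization scales into both the cache write (`set_kv_buffer`) and the FlashInfer prefill/decode calls. As a rough illustration of the scaling convention this implies (a minimal sketch, not the actual `set_kv_buffer` implementation; the helper names and the fp8 dtype are assumptions):

```python
import torch

def quantize_kv(k: torch.Tensor, k_scale: float) -> torch.Tensor:
    # Hypothetical helper: store K divided by its scale so it fits the fp8 range.
    return (k / k_scale).to(torch.float8_e4m3fn)

def dequantize_kv(k_fp8: torch.Tensor, k_scale: float, dtype=torch.bfloat16) -> torch.Tensor:
    # Multiplying the scale back in is what passing k_scale / v_scale
    # to the FlashInfer forward() calls accomplishes inside the kernel.
    return k_fp8.to(dtype) * k_scale
```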
sglang/srt/layers/linear.py

@@ -1,4 +1,4 @@
-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

 import logging
 from abc import abstractmethod
@@ -16,16 +16,16 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )

-# workaround
+# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.parameter import (
+
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
 )
-
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -42,8 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
-    "GPTQLinearMethod",
     "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
+    "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
 ]


@@ -286,6 +291,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -294,7 +301,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output

         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -335,7 +346,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)

         # Special case for GGUF
@@ -355,7 +365,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -372,7 +382,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -392,7 +402,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s

@@ -430,10 +440,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -443,6 +461,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def weight_loader(
@@ -462,12 +482,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -521,11 +538,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -544,10 +559,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -623,31 +638,33 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

         assert loaded_shard_id < len(self.output_sizes)

-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )


@@ -688,6 +705,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -696,7 +715,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -723,6 +746,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -799,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -819,6 +845,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            tp_rank=self.tp_rank,
         )

     def weight_loader(
@@ -839,12 +866,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -933,7 +957,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return

-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]

         # If output dim is defined, use the default loading process.
@@ -983,9 +1006,9 @@ class QKVParallelLinear(ColumnParallelLinear):

             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size

             # bitsandbytes loads the weights of the specific portion
@@ -1054,6 +1077,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1063,10 +1089,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results

         # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights

         self.quant_method.create_weights(
             layer=self,
@@ -1100,8 +1130,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

@@ -1115,15 +1143,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -1142,17 +1174,27 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        if isinstance(param, BasevLLMParameter):
+            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+            # It supports additional parameters like tp_rank and use_presharded_weights.
+            param.load_row_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            # `params` is defined in `vllm/model_executor/parameter.py`,
+            # It does not support additional parameters.
+            param.load_row_parallel_weight(loaded_weight)

     def forward(self, input_):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         assert self.quant_method is not None
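Note: the recurring change in this file captures `tp_rank` / `tp_size` once in `__init__` (optionally overridden by the caller) instead of calling `get_tensor_model_parallel_rank()` / `get_tensor_model_parallel_world_size()` inside every weight loader, and adds `use_presharded_weights` to skip the `narrow()` step for checkpoints that are already sharded per rank. A hypothetical caller, assuming the tensor-parallel process group has been initialized and with purely illustrative sizes:

```python
from sglang.srt.layers.linear import RowParallelLinear

# Default: rank/size are taken from the current tensor-parallel process group.
proj = RowParallelLinear(input_size=4096, output_size=4096, bias=False)

# Explicit shard layout plus weights that are already sharded on disk,
# so weight_loader() does not narrow() the loaded tensor again.
proj_presharded = RowParallelLinear(
    input_size=4096,
    output_size=4096,
    bias=False,
    tp_rank=0,
    tp_size=8,
    use_presharded_weights=True,
)
```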
sglang/srt/layers/logits_processor.py

@@ -74,11 +74,6 @@ class LogitsMetadata:

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
-
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:

         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
@@ -122,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None

     def forward(
         self,
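Note: the new guard treats a negative `final_logit_softcapping` value in the HF config as "disabled". Final-logit softcapping is conventionally the tanh squashing shown below; this is only a reference formulation of what the setting controls, not the processor's actual code path:

```python
import torch
from typing import Optional

def soft_cap(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # A missing or negative cap now means "no capping".
    if cap is None or cap < 0:
        return logits
    return cap * torch.tanh(logits / cap)
```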
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -1011,11 +1011,22 @@ def fused_experts_impl(
                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
                )
            else:
-                torch.sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    dim=1,
-                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+                if topk_ids.shape[1] == 1:
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                        intermediate_cache3[:, 0]
+                    )
+                elif topk_ids.shape[1] == 2:
+                    torch.add(
+                        intermediate_cache3[:, 0],
+                        intermediate_cache3[:, 1],
+                        out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                    ).squeeze(dim=1)
+                elif topk_ids.shape[1] > 2:
+                    torch.sum(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        dim=1,
+                        out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                    )

    return out_hidden_states

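Note: the generic `torch.sum` over the top-k axis is replaced by a plain copy for top-1 routing and a single add for top-2, which are numerically equivalent but avoid the reduction kernel. A quick equivalence check for the top-2 case:

```python
import torch

cache = torch.randn(8, 2, 16)            # [tokens, topk=2, hidden]
reference = cache.sum(dim=1)              # old path: generic reduction over top-k
specialized = cache[:, 0] + cache[:, 1]   # new path when topk == 2
assert torch.allclose(reference, specialized)
```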
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -27,6 +27,8 @@ else:

 import logging

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)


@@ -97,6 +99,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,14 +164,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-        return fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )

     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError("The CPU backend currently does not support MoE.")
@@ -204,6 +232,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -243,6 +272,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +425,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
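Note: on ROCm builds (`is_hip_`), the unquantized MoE path can be routed through ater's `fused_experts_ck`, with the expert weights pre-permuted after loading; both branches are gated on the `CK_MOE` environment variable via `get_bool_env_var`. A hedged usage sketch (the accepted truthy strings are assumed to be the usual "1"/"true"):

```python
import os

# Opt in to the AMD CK MoE kernels before the server process starts.
# This has no effect on CUDA builds, where is_hip_ is False.
os.environ["CK_MOE"] = "1"
```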
sglang/srt/layers/moe/topk.py

@@ -24,7 +24,9 @@ def fused_topk_native(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert (
+        hidden_states.shape[0] == gating_output.shape[0]
+    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
         M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
-    elif torch_native:
+    elif torch_native and custom_routing_function is None:
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
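Note: `select_experts` now skips the torch-native top-k fallback whenever a `custom_routing_function` is supplied. For reference, native routing implements the usual softmax-then-top-k selection; the sketch below mirrors those semantics and is illustrative rather than a copy of `fused_topk_native`:

```python
import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts, keep the top-k probabilities and their expert ids,
    # and optionally renormalize the kept weights so they sum to 1 per token.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```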