sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -84,6 +84,10 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

+        # Qwen2 models require higher flashinfer workspace size
+        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+
         # Allocate buffers
         self.workspace_buffer = torch.empty(
             global_config.flashinfer_workspace_size,
@@ -347,11 +351,15 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

+        logits_soft_cap = layer.logit_cap
+
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -359,7 +367,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=
+                logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -368,7 +378,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=
+                logits_soft_cap=logits_soft_cap,
             )

             if self.forward_metadata.extend_no_prefix:
@@ -385,7 +395,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -410,13 +422,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
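Taken together, these hunks thread each layer's `k_scale` / `v_scale` through both the KV-cache write (`set_kv_buffer`) and the flashinfer prefill/decode calls. As a rough, hypothetical sketch of why the write side needs the scale for a low-precision KV cache (the real pool lives in `memory_pool.py` and differs in detail):

```python
import torch
from typing import Optional

def write_scaled_kv(k_buffer: torch.Tensor, loc: torch.Tensor,
                    k: torch.Tensor, k_scale: Optional[float]) -> None:
    # Hypothetical helper: divide by the per-layer scale before the narrow-dtype
    # cast; the same scale is then handed to the attention kernel so the cached
    # values can be rescaled when they are read back.
    if k_scale is not None:
        k = k / k_scale
    k_buffer[loc] = k.to(k_buffer.dtype)

buf = torch.zeros(16, 4, dtype=torch.float16)          # stand-in cache buffer
write_scaled_kv(buf, torch.tensor([0, 1]), torch.randn(2, 4), k_scale=2.0)
```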
sglang/srt/layers/linear.py
CHANGED
@@ -1,4 +1,4 @@
-
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

 import logging
 from abc import abstractmethod
@@ -16,16 +16,16 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )

-#
+# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
 from vllm.model_executor.layers.linear import LinearBase
-
+
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
 )
-
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -42,8 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
-    "GPTQLinearMethod",
     "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
+    "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
 ]


@@ -286,6 +291,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -294,7 +301,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output

         # Divide the weight matrix along the last dimension.
-
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
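The recurring `tp_rank` / `tp_size` additions in this file all follow the same default-or-override shape: an explicitly passed value wins, otherwise the layer falls back to the global tensor-parallel group. Distilled into a standalone helper (names here are illustrative, not part of sglang):

```python
from typing import Optional, Tuple

def resolve_tp(tp_rank: Optional[int], tp_size: Optional[int],
               group_rank: int, group_size: int) -> Tuple[int, int]:
    # Explicit values (e.g. for pre-sharded or replicated weights) take
    # precedence; otherwise use the process group's rank and world size.
    if tp_rank is None:
        tp_rank = group_rank
    if tp_size is None:
        tp_size = group_size
    return tp_rank, tp_size

assert resolve_tp(None, None, group_rank=3, group_size=8) == (3, 8)
assert resolve_tp(0, 1, group_rank=3, group_size=8) == (0, 1)
```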
@@ -335,7 +346,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)

         # Special case for GGUF
@@ -355,7 +365,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -372,7 +382,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=
+        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -392,7 +402,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s

@@ -430,10 +440,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -443,6 +461,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def weight_loader(
@@ -462,12 +482,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -521,11 +538,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -544,10 +559,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

         param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-        start_idx = tp_rank * shard_size
+        start_idx = self.tp_rank * shard_size
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if not use_bitsandbytes_4bit:
+        if not use_bitsandbytes_4bit and not self.use_presharded_weights:
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -623,31 +638,33 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

         assert loaded_shard_id < len(self.output_sizes)

-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )

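The `use_presharded_weights` flag captures a simple invariant: if the checkpoint on disk already contains only this rank's shard, the loader must skip the `narrow()` that would otherwise slice the shard a second time. A standalone sketch of that branch (tensor names and sizes are illustrative):

```python
import torch

def load_shard(weight: torch.Tensor, output_dim: int, shard_size: int,
               tp_rank: int, use_presharded_weights: bool) -> torch.Tensor:
    # Pre-sharded checkpoints already store exactly `shard_size` rows per rank,
    # so slicing again would read out of range or pick the wrong slice.
    if use_presharded_weights:
        return weight
    start_idx = tp_rank * shard_size
    return weight.narrow(output_dim, start_idx, shard_size)

full = torch.arange(8.0).reshape(8, 1)   # unsharded checkpoint, tp_size = 2
shard = load_shard(full, output_dim=0, shard_size=4, tp_rank=1,
                   use_presharded_weights=False)
assert shard.shape[0] == 4
```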
@@ -688,6 +705,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -696,7 +715,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -723,6 +746,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -799,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

@@ -819,6 +845,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            tp_rank=self.tp_rank,
         )

     def weight_loader(
@@ -839,12 +866,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -933,7 +957,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return

-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]

         # If output dim is defined, use the default loading process.
@@ -983,9 +1006,9 @@ class QKVParallelLinear(ColumnParallelLinear):

             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size

             # bitsandbytes loads the weights of the specific portion
@@ -1054,6 +1077,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1063,10 +1089,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results

         # Divide the weight matrix along the last dimension.
-
-
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights

         self.quant_method.create_weights(
             layer=self,
@@ -1100,8 +1130,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

@@ -1115,15 +1143,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -1142,17 +1174,27 @@ class RowParallelLinear(LinearBase):
         assert loaded_weight.numel() == 1
         loaded_weight = loaded_weight.reshape(1)

-        param
+        if isinstance(param, BasevLLMParameter):
+            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+            # It supports additional parameters like tp_rank and use_presharded_weights.
+            param.load_row_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            # `params` is defined in `vllm/model_executor/parameter.py`,
+            # It does not support additional parameters.
+            param.load_row_parallel_weight(loaded_weight)

     def forward(self, input_):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         assert self.quant_method is not None
sglang/srt/layers/logits_processor.py
CHANGED
@@ -74,11 +74,6 @@ class LogitsMetadata:

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
-
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:

         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
@@ -122,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None

     def forward(
         self,
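For reference, logit soft-capping (when `final_logit_softcapping` is a positive value) is conventionally applied as a scaled `tanh`; the new guard simply treats a negative config value the same as "disabled". A minimal sketch of the transform, not the exact body of `LogitsProcessor`:

```python
import torch
from typing import Optional

def softcap(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # No-op when capping is disabled (None; the hunk above also maps
    # negative config values to None).
    if cap is None:
        return logits
    # Squash logits smoothly into (-cap, cap) while staying ~linear near zero.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, 0.0, 100.0])
print(softcap(x, 30.0))   # ≈ tensor([-29.9, 0.0, 29.9])
```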
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -1011,11 +1011,22 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-
-
-
-
-
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )

     return out_hidden_states

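This branch tunes the final "combine" step of the MoE forward: the per-token outputs of the selected experts (shape `[tokens, topk, hidden]`) are reduced over the top-k dimension, with the top-1 and top-2 cases special-cased to a copy and an add instead of the generic sum. A small standalone sketch of the equivalence (shapes are made up for illustration):

```python
import torch

tokens, topk, hidden = 4, 2, 8
expert_out = torch.randn(tokens, topk, hidden)   # stands in for intermediate_cache3
out = torch.empty(tokens, hidden)

if topk == 1:
    out.copy_(expert_out[:, 0])                  # single expert: plain copy
elif topk == 2:
    torch.add(expert_out[:, 0], expert_out[:, 1], out=out)
else:
    torch.sum(expert_out, dim=1, out=out)        # generic reduction

assert torch.allclose(out, expert_out.sum(dim=1))
```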
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -27,6 +27,8 @@ else:

 import logging

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)


@@ -97,6 +99,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,14 +164,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-
-
-
-
-
-
-
-
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )

     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError("The CPU backend currently does not support MoE.")
@@ -204,6 +232,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -243,6 +272,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +425,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
sglang/srt/layers/moe/topk.py
CHANGED
@@ -24,7 +24,9 @@ def fused_topk_native(
     topk: int,
     renormalize: bool,
 ):
-    assert
+    assert (
+        hidden_states.shape[0] == gating_output.shape[0]
+    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
         M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
         )
-    elif torch_native:
+    elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
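For context, a torch-native top-k routing step of this shape typically computes softmax router probabilities, picks the top-k experts per token, and optionally renormalizes the selected weights so they sum to one. A minimal, self-contained sketch of that pattern (not the exact body of `fused_topk_native`):

```python
import torch

def topk_route(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # gating_output: [num_tokens, num_experts] router logits
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

logits = torch.randn(4, 8)          # 4 tokens routed over 8 experts
w, ids = topk_route(logits, topk=2, renormalize=True)
assert torch.allclose(w.sum(dim=-1), torch.ones(4))
```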