sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +4 -2
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +10 -6
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +6 -2
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +10 -4
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +9 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -0
- sglang/srt/server.py +11 -8
- sglang/srt/server_args.py +12 -1
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +47 -33
- sglang/srt/utils.py +32 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -18,14 +18,15 @@ from vllm.distributed import (
 
 # workaround
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.parameter import (
+
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
-
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -44,6 +45,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "MarlinLinearMethod",
     "GPTQLinearMethod",
     "QQQLinearMethod",
+    "ModelOptFp8LinearMethod",
 ]
 
 
@@ -93,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def load_column_qkv_weight(
+    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
+):
+    if (
+        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+        and self.output_dim == self.packed_dim
+    ):
+        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+            shard_offset=shard_offset, shard_size=shard_size
+        )
+
+    param_data = self.data
+    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(
+        self.output_dim, shard_id * shard_size, shard_size
+    )
+
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
+
+
+def load_column_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, _ColumnvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
+def load_row_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, RowvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
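The three helpers above share one rule: when the checkpoint stores the full tensor, slice this rank's shard out along the parallel dimension; when use_presharded_weights is set, the tensor on disk is already the per-rank shard and is copied as-is. A minimal standalone sketch of that slicing logic, using plain torch tensors rather than the sglang parameter classes:

import torch

def load_shard(param, loaded, dim, tp_rank, use_presharded_weights=False):
    # Illustrative helper, not the sglang API: `param` is the local shard
    # buffer, `loaded` is the tensor read from the checkpoint.
    if not use_presharded_weights:
        shard_size = param.shape[dim]
        loaded = loaded.narrow(dim, tp_rank * shard_size, shard_size)
    assert param.shape == loaded.shape, f"{param.shape=}, {loaded.shape=}"
    param.copy_(loaded)

full = torch.arange(24.0).reshape(8, 3)   # full checkpoint weight, output_dim=0
shard = torch.empty(2, 3)                 # rank 1 of tp_size=4 owns rows 2..3
load_shard(shard, full, dim=0, tp_rank=1)
assert torch.equal(shard, full[2:4])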
@@ -286,6 +344,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -294,7 +354,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output
 
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
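The pattern repeated in these constructors is an optional override with a distributed fallback: explicit tp_rank/tp_size arguments win, and only None falls back to the process-group helpers. A small self-contained sketch of that resolution step; the defaults below merely stand in for get_tensor_model_parallel_rank/get_tensor_model_parallel_world_size:

from typing import Optional, Tuple

def resolve_tp(tp_rank: Optional[int], tp_size: Optional[int],
               default_rank: int = 0, default_size: int = 1) -> Tuple[int, int]:
    # default_rank/default_size are placeholders for the distributed lookups.
    if tp_rank is None:
        tp_rank = default_rank
    if tp_size is None:
        tp_size = default_size
    return tp_rank, tp_size

assert resolve_tp(None, None) == (0, 1)   # fall back to the process group
assert resolve_tp(3, 8) == (3, 8)         # caller-supplied override wins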
@@ -335,7 +399,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
         # Special case for GGUF
@@ -355,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -363,7 +426,9 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -392,7 +457,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s
 
@@ -430,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -443,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def weight_loader(
@@ -462,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
@@ -493,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )
 
-                assert param_data.shape == loaded_weight.shape
+                assert (
+                    param_data.shape == loaded_weight.shape
+                ), f"{param_data.shape=}, {loaded_weight.shape=}"
                 param_data.copy_(loaded_weight)
                 return
             current_shard_offset = 0
@@ -521,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -544,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -571,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 "the same for all partitions."
             )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def _load_fused_module_from_checkpoint(
@@ -628,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
         assert loaded_shard_id < len(self.output_sizes)
 
-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
 
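For merged column-parallel weights, each fused sub-weight's shard offset and size come from dividing the cumulative and per-output sizes by the cached self.tp_size. A worked sketch of that arithmetic with hypothetical gate/up projection sizes:

# Hypothetical fused output sizes (e.g. gate_proj and up_proj) on tp_size=4.
output_sizes = [4096, 11008]
tp_size = 4

def shard_bounds(loaded_shard_id: int):
    shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    shard_size = output_sizes[loaded_shard_id] // tp_size
    return shard_offset, shard_size

assert shard_bounds(0) == (0, 1024)       # first sub-weight starts at offset 0
assert shard_bounds(1) == (1024, 2752)    # second starts after the first shard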
@@ -688,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -696,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -723,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -813,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n
 
-        param.load_qkv_weight(
-            loaded_weight=loaded_weight,
-            num_heads=self.num_kv_head_replicas,
-            shard_id=loaded_shard_id,
-            shard_offset=shard_offset,
-            shard_size=shard_size,
-        )
+        if isinstance(param, _ColumnvLLMParameter):
+            load_column_qkv_weight(
+                param,
+                loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+                tp_rank=self.tp_rank,
+            )
+        else:
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+            )
 
     def weight_loader(
         self,
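In the QKV path, the shard index used to slice the checkpoint differs by projection: q uses the tensor-parallel rank directly, while k and v divide the rank by the number of KV-head replicas so ranks in the same replica group load the same KV shard. A minimal sketch of that index math (the numbers are illustrative):

def qkv_shard_index(shard_id: str, tp_rank: int, num_kv_head_replicas: int) -> int:
    # Mirrors the q vs. k/v branch; values below are made up for illustration.
    return tp_rank if shard_id == "q" else tp_rank // num_kv_head_replicas

assert qkv_shard_index("q", 5, 4) == 5    # q: one shard per rank
assert qkv_shard_index("k", 5, 4) == 1    # k/v: ranks 4..7 share replica group 1
assert qkv_shard_index("v", 3, 4) == 0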
@@ -839,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
@@ -871,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )
 
-                assert param_data.shape == loaded_weight.shape
+                assert (
+                    param_data.shape == loaded_weight.shape
+                ), f"{param_data.shape=}, {loaded_weight.shape=}"
                 param_data.copy_(loaded_weight)
                 return
             shard_offsets = [
@@ -933,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -983,9 +1075,9 @@ class QKVParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             # bitsandbytes loads the weights of the specific portion
@@ -1013,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                 "for all partitions."
             )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
@@ -1054,6 +1148,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1063,10 +1160,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results
 
         # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights
 
         self.quant_method.create_weights(
             layer=self,
@@ -1100,8 +1201,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
@@ -1115,15 +1214,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -1131,7 +1234,9 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1148,11 +1253,10 @@ class RowParallelLinear(LinearBase):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
         assert self.quant_method is not None
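RowParallelLinear.forward now indexes the pre-split activation with the cached self.tp_rank instead of querying the rank again. A short sketch of splitting an input along its last dimension and keeping one rank's chunk; torch.chunk stands in here for split_tensor_along_last_dim:

import torch

def take_rank_chunk(x, tp_rank, tp_size):
    # torch.chunk is a stand-in for split_tensor_along_last_dim in this sketch.
    chunks = torch.chunk(x, tp_size, dim=-1)
    return chunks[tp_rank].contiguous()

x = torch.randn(2, 16)
part = take_rank_chunk(x, tp_rank=3, tp_size=4)
assert part.shape == (2, 4)
assert torch.equal(part, x[:, 12:16])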
sglang/srt/layers/logits_processor.py
CHANGED
@@ -74,11 +74,6 @@ class LogitsMetadata:
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
-
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:
 
         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
@@ -122,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None
 
     def forward(
         self,
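LogitsProcessor now treats a negative final_logit_softcapping as disabled. For reference, soft-capping squashes logits through tanh and rescales so their magnitude stays below the cap; a sketch of the usual formulation (not the in-place sglang kernel):

import torch
from typing import Optional

def soft_cap(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # A None cap (or, after this change, a negative configured value) is a no-op.
    if cap is None:
        return logits
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, 0.0, 100.0])
assert soft_cap(x, 30.0).abs().max() <= 30.0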
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -1011,11 +1011,22 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            torch.sum(
-                intermediate_cache3.view(*intermediate_cache3.shape),
-                dim=1,
-                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-            )
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
 
     return out_hidden_states
 
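The new branches specialize the final top-k reduction over per-expert outputs: top-1 is a plain copy, top-2 a single add, and larger k falls back to torch.sum along the top-k dimension. A sketch checking that the three paths agree on random data (shapes are illustrative):

import torch

tokens, hidden = 4, 8
for topk in (1, 2, 3):
    cache3 = torch.randn(tokens, topk, hidden)   # stand-in for intermediate_cache3
    out = torch.empty(tokens, hidden)
    if topk == 1:
        out.copy_(cache3[:, 0])
    elif topk == 2:
        torch.add(cache3[:, 0], cache3[:, 1], out=out)
    else:
        torch.sum(cache3, dim=1, out=out)
    assert torch.allclose(out, cache3.sum(dim=1))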
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
 
@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights
 
     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +397,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
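With this change, use_presharded_weights becomes part of FusedMoE's construction-time configuration rather than a flag re-passed (and re-assigned) on every weight_loader call. A minimal sketch of a loader reading such a flag from the instance; the class below is hypothetical, not the sglang module:

import torch

class TinyMoE:
    # Hypothetical stand-in: per-expert weights, each sharded on its first dim.
    def __init__(self, num_experts, hidden, shard, use_presharded_weights=False):
        self.use_presharded_weights = use_presharded_weights   # fixed at init
        self.w = torch.empty(num_experts, shard, hidden)

    def weight_loader(self, expert_id, loaded, tp_rank):
        if not self.use_presharded_weights:
            shard = self.w.shape[1]
            loaded = loaded.narrow(0, tp_rank * shard, shard)
        self.w[expert_id].copy_(loaded)

moe = TinyMoE(num_experts=2, hidden=4, shard=3)
moe.weight_loader(0, torch.randn(6, 4), tp_rank=1)   # full weight -> rank shard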