sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/triton_backend.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
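
For reference, the CUDA-graph metadata hooks above now take num_tokens, forward_mode, and spec_info in addition to the batch shape. A minimal caller sketch under assumptions (attn_backend, req_pool_indices, and seq_lens are hypothetical objects prepared elsewhere, e.g. by a ModelRunner; they are not defined in this diff):

    # Hypothetical sketch, not part of the wheel: how a CUDA-graph capture path
    # could call the widened hook after this change.
    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    bs = 8  # illustrative capture batch size
    attn_backend.init_forward_metadata_capture_cuda_graph(
        bs=bs,
        num_tokens=bs,                    # one token per request in plain decode
        req_pool_indices=req_pool_indices[:bs],
        seq_lens=seq_lens[:bs],
        encoder_lens=None,                # the Triton backend asserts this is None
        forward_mode=ForwardMode.DECODE,  # must satisfy forward_mode.is_decode()
        spec_info=None,                   # no speculative decoding in this sketch
    )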

sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
+    # [TODO] work around shmem limit on MI3xx
+    if is_hip_ and Lk >= 576:
+        BLOCK = 16
+
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
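
The new guard above is gated on a module-level is_hip_ flag. A hedged orientation sketch of how such a ROCm check is typically derived (sglang ships its own is_hip() helper in sglang.srt.utils; this standalone version is only for illustration):

    # Illustration only: detect a ROCm/HIP build of PyTorch.
    import torch

    def is_hip() -> bool:
        # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
        return torch.version.hip is not None

    is_hip_ = is_hip()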

sglang/srt/layers/linear.py
@@ -18,14 +18,15 @@ from vllm.distributed import (
 
 # workaround
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.parameter import (
+
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
-
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -44,6 +45,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "MarlinLinearMethod",
     "GPTQLinearMethod",
     "QQQLinearMethod",
+    "ModelOptFp8LinearMethod",
 ]
 
 
@@ -93,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def load_column_qkv_weight(
+    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
+):
+    if (
+        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+        and self.output_dim == self.packed_dim
+    ):
+        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+            shard_offset=shard_offset, shard_size=shard_size
+        )
+
+    param_data = self.data
+    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(
+        self.output_dim, shard_id * shard_size, shard_size
+    )
+
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
+
+
+def load_column_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, _ColumnvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
+def load_row_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, RowvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
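
These loader helpers take an explicit tp_rank and a use_presharded_weights flag instead of reading the global tensor-parallel group. A small illustrative sketch (shapes and values made up, not part of the package) of what the non-presharded row-parallel path boils down to:

    # Illustration only: the narrow-and-copy that load_row_parallel_weight
    # performs when the checkpoint holds the full (unsharded) tensor.
    import torch

    tp_rank, tp_size = 1, 4
    full_weight = torch.randn(1024, 4096)             # checkpoint tensor, input dim = 1
    local_param = torch.empty(1024, 4096 // tp_size)  # this rank's shard

    shard_size = local_param.shape[1]
    shard = full_weight.narrow(1, tp_rank * shard_size, shard_size)
    assert local_param.shape == shard.shape
    local_param.copy_(shard)

    # With use_presharded_weights=True the checkpoint already stores only this
    # rank's slice, so the narrow() step is skipped and the tensor is copied as is.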
 
@@ -286,6 +344,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -294,7 +354,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output
 
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -335,7 +399,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
         # Special case for GGUF
@@ -355,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -363,7 +426,9 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -392,7 +457,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s
 
@@ -430,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -443,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def weight_loader(
@@ -462,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
@@ -493,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )
 
-                assert param_data.shape == loaded_weight.shape
+                assert (
+                    param_data.shape == loaded_weight.shape
+                ), f"{param_data.shape=}, {loaded_weight.shape=}"
                 param_data.copy_(loaded_weight)
                 return
             current_shard_offset = 0
@@ -521,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -544,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -571,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     "the same for all partitions."
                 )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def _load_fused_module_from_checkpoint(
@@ -628,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
         assert loaded_shard_id < len(self.output_sizes)
 
-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
            shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
 
@@ -688,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -696,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -723,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -813,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n
 
-        param.load_qkv_weight(
-            loaded_weight=loaded_weight,
-            num_heads=self.num_kv_head_replicas,
-            shard_id=loaded_shard_id,
-            shard_offset=shard_offset,
-            shard_size=shard_size,
-        )
+        if isinstance(param, _ColumnvLLMParameter):
+            load_column_qkv_weight(
+                param,
+                loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+                tp_rank=self.tp_rank,
+            )
+        else:
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+            )
 
     def weight_loader(
         self,
@@ -839,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
@@ -871,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                    )
 
-                assert param_data.shape == loaded_weight.shape
+                assert (
+                    param_data.shape == loaded_weight.shape
+                ), f"{param_data.shape=}, {loaded_weight.shape=}"
                 param_data.copy_(loaded_weight)
                 return
             shard_offsets = [
@@ -933,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -983,9 +1075,9 @@
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             # bitsandbytes loads the weights of the specific portion
@@ -1013,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                     "for all partitions."
                 )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
@@ -1054,6 +1148,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1063,10 +1160,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results
 
         # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights
 
         self.quant_method.create_weights(
             layer=self,
@@ -1100,8 +1201,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
@@ -1115,15 +1214,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -1131,7 +1234,9 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1148,11 +1253,10 @@ class RowParallelLinear(LinearBase):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
         assert self.quant_method is not None
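
Taken together, the parallel linear layers in this release accept an explicit tp_rank/tp_size pair (and, for the merged and row variants, use_presharded_weights) instead of always querying the global tensor-parallel group. A hypothetical construction sketch under those assumptions (distributed setup and checkpoint loading are handled elsewhere):

    # Illustration only: overriding the TP topology per layer.
    from sglang.srt.layers.linear import RowParallelLinear

    proj = RowParallelLinear(
        input_size=4096,
        output_size=4096,
        bias=False,
        tp_rank=0,                    # explicit rank/size instead of the global TP group
        tp_size=2,
        use_presharded_weights=True,  # checkpoint already stores per-rank shards
    )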