sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_serving.py +18 -1
 - sglang/lang/interpreter.py +71 -1
 - sglang/lang/ir.py +2 -0
 - sglang/srt/configs/__init__.py +4 -0
 - sglang/srt/configs/chatglm.py +78 -0
 - sglang/srt/configs/dbrx.py +279 -0
 - sglang/srt/configs/model_config.py +1 -1
 - sglang/srt/hf_transformers_utils.py +9 -14
 - sglang/srt/layers/attention/__init__.py +8 -1
 - sglang/srt/layers/attention/flashinfer_backend.py +4 -2
 - sglang/srt/layers/linear.py +159 -55
 - sglang/srt/layers/logits_processor.py +6 -6
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
 - sglang/srt/layers/parameter.py +431 -0
 - sglang/srt/layers/quantization/__init__.py +3 -2
 - sglang/srt/layers/quantization/fp8.py +1 -1
 - sglang/srt/layers/quantization/modelopt_quant.py +174 -0
 - sglang/srt/layers/vocab_parallel_embedding.py +1 -1
 - sglang/srt/managers/cache_controller.py +307 -0
 - sglang/srt/managers/data_parallel_controller.py +2 -0
 - sglang/srt/managers/schedule_batch.py +7 -1
 - sglang/srt/managers/scheduler.py +10 -6
 - sglang/srt/managers/session_controller.py +1 -1
 - sglang/srt/managers/tokenizer_manager.py +6 -2
 - sglang/srt/mem_cache/memory_pool.py +206 -1
 - sglang/srt/metrics/collector.py +22 -30
 - sglang/srt/model_executor/cuda_graph_runner.py +14 -7
 - sglang/srt/model_executor/forward_batch_info.py +20 -15
 - sglang/srt/model_executor/model_runner.py +10 -4
 - sglang/srt/models/chatglm.py +1 -1
 - sglang/srt/models/dbrx.py +1 -1
 - sglang/srt/models/grok.py +25 -16
 - sglang/srt/models/llama.py +9 -2
 - sglang/srt/sampling/sampling_batch_info.py +1 -0
 - sglang/srt/server.py +11 -8
 - sglang/srt/server_args.py +12 -1
 - sglang/srt/speculative/eagle_utils.py +93 -85
 - sglang/srt/speculative/eagle_worker.py +47 -33
 - sglang/srt/utils.py +32 -5
 - sglang/test/test_programs.py +23 -1
 - sglang/test/test_utils.py +36 -7
 - sglang/version.py +1 -1
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
 
    
        sglang/srt/layers/linear.py
    CHANGED
    
    | 
         @@ -18,14 +18,15 @@ from vllm.distributed import ( 
     | 
|
| 
       18 
18 
     | 
    
         | 
| 
       19 
19 
     | 
    
         
             
            # workaround
         
     | 
| 
       20 
20 
     | 
    
         
             
            from vllm.model_executor.layers.linear import LinearBase
         
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
            from sglang.srt.layers.parameter import (
         
     | 
| 
       22 
23 
     | 
    
         
             
                BasevLLMParameter,
         
     | 
| 
       23 
24 
     | 
    
         
             
                PackedColumnParameter,
         
     | 
| 
       24 
25 
     | 
    
         
             
                PackedvLLMParameter,
         
     | 
| 
       25 
26 
     | 
    
         
             
                PerTensorScaleParameter,
         
     | 
| 
       26 
27 
     | 
    
         
             
                RowvLLMParameter,
         
     | 
| 
      
 28 
     | 
    
         
            +
                _ColumnvLLMParameter,
         
     | 
| 
       27 
29 
     | 
    
         
             
            )
         
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
30 
     | 
    
         
             
            from sglang.srt.layers.quantization.base_config import (
         
     | 
| 
       30 
31 
     | 
    
         
             
                QuantizationConfig,
         
     | 
| 
       31 
32 
     | 
    
         
             
                QuantizeMethodBase,
         
     | 
| 
         @@ -44,6 +45,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ 
     | 
|
| 
       44 
45 
     | 
    
         
             
                "MarlinLinearMethod",
         
     | 
| 
       45 
46 
     | 
    
         
             
                "GPTQLinearMethod",
         
     | 
| 
       46 
47 
     | 
    
         
             
                "QQQLinearMethod",
         
     | 
| 
      
 48 
     | 
    
         
            +
                "ModelOptFp8LinearMethod",
         
     | 
| 
       47 
49 
     | 
    
         
             
            ]
         
     | 
| 
       48 
50 
     | 
    
         | 
| 
       49 
51 
     | 
    
         | 
| 
         @@ -93,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): 
     | 
|
| 
       93 
95 
     | 
    
         
             
                return param[shard_id], loaded_weight
         
     | 
| 
       94 
96 
     | 
    
         | 
| 
       95 
97 
     | 
    
         | 
| 
      
 98 
     | 
    
         
            +
            def load_column_qkv_weight(
         
     | 
| 
      
 99 
     | 
    
         
            +
                self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
         
     | 
| 
      
 100 
     | 
    
         
            +
            ):
         
     | 
| 
      
 101 
     | 
    
         
            +
                if (
         
     | 
| 
      
 102 
     | 
    
         
            +
                    isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
         
     | 
| 
      
 103 
     | 
    
         
            +
                    and self.output_dim == self.packed_dim
         
     | 
| 
      
 104 
     | 
    
         
            +
                ):
         
     | 
| 
      
 105 
     | 
    
         
            +
                    shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
         
     | 
| 
      
 106 
     | 
    
         
            +
                        shard_offset=shard_offset, shard_size=shard_size
         
     | 
| 
      
 107 
     | 
    
         
            +
                    )
         
     | 
| 
      
 108 
     | 
    
         
            +
             
     | 
| 
      
 109 
     | 
    
         
            +
                param_data = self.data
         
     | 
| 
      
 110 
     | 
    
         
            +
                shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         
     | 
| 
      
 111 
     | 
    
         
            +
                param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         
     | 
| 
      
 112 
     | 
    
         
            +
                loaded_weight = loaded_weight.narrow(
         
     | 
| 
      
 113 
     | 
    
         
            +
                    self.output_dim, shard_id * shard_size, shard_size
         
     | 
| 
      
 114 
     | 
    
         
            +
                )
         
     | 
| 
      
 115 
     | 
    
         
            +
             
     | 
| 
      
 116 
     | 
    
         
            +
                assert param_data.shape == loaded_weight.shape
         
     | 
| 
      
 117 
     | 
    
         
            +
                param_data.copy_(loaded_weight)
         
     | 
| 
      
 118 
     | 
    
         
            +
             
     | 
| 
      
 119 
     | 
    
         
            +
             
     | 
| 
      
 120 
     | 
    
         
            +
            def load_column_parallel_weight(
         
     | 
| 
      
 121 
     | 
    
         
            +
                self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
         
     | 
| 
      
 122 
     | 
    
         
            +
            ):
         
     | 
| 
      
 123 
     | 
    
         
            +
                if isinstance(self, _ColumnvLLMParameter):
         
     | 
| 
      
 124 
     | 
    
         
            +
                    if not use_presharded_weights:
         
     | 
| 
      
 125 
     | 
    
         
            +
                        shard_size = self.data.shape[self.output_dim]
         
     | 
| 
      
 126 
     | 
    
         
            +
                        loaded_weight = loaded_weight.narrow(
         
     | 
| 
      
 127 
     | 
    
         
            +
                            self.output_dim, tp_rank * shard_size, shard_size
         
     | 
| 
      
 128 
     | 
    
         
            +
                        )
         
     | 
| 
      
 129 
     | 
    
         
            +
                    assert self.data.shape == loaded_weight.shape
         
     | 
| 
      
 130 
     | 
    
         
            +
                    self.data.copy_(loaded_weight)
         
     | 
| 
      
 131 
     | 
    
         
            +
                else:
         
     | 
| 
      
 132 
     | 
    
         
            +
                    self.data.copy_(loaded_weight)
         
     | 
| 
      
 133 
     | 
    
         
            +
             
     | 
| 
      
 134 
     | 
    
         
            +
             
     | 
| 
      
 135 
     | 
    
         
            +
            def load_row_parallel_weight(
         
     | 
| 
      
 136 
     | 
    
         
            +
                self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
         
     | 
| 
      
 137 
     | 
    
         
            +
            ):
         
     | 
| 
      
 138 
     | 
    
         
            +
                if isinstance(self, RowvLLMParameter):
         
     | 
| 
      
 139 
     | 
    
         
            +
                    if not use_presharded_weights:
         
     | 
| 
      
 140 
     | 
    
         
            +
                        shard_size = self.data.shape[self.input_dim]
         
     | 
| 
      
 141 
     | 
    
         
            +
                        loaded_weight = loaded_weight.narrow(
         
     | 
| 
      
 142 
     | 
    
         
            +
                            self.input_dim, tp_rank * shard_size, shard_size
         
     | 
| 
      
 143 
     | 
    
         
            +
                        )
         
     | 
| 
      
 144 
     | 
    
         
            +
             
     | 
| 
      
 145 
     | 
    
         
            +
                    if len(loaded_weight.shape) == 0:
         
     | 
| 
      
 146 
     | 
    
         
            +
                        loaded_weight = loaded_weight.reshape(1)
         
     | 
| 
      
 147 
     | 
    
         
            +
             
     | 
| 
      
 148 
     | 
    
         
            +
                    assert self.data.shape == loaded_weight.shape
         
     | 
| 
      
 149 
     | 
    
         
            +
                    self.data.copy_(loaded_weight)
         
     | 
| 
      
 150 
     | 
    
         
            +
                else:
         
     | 
| 
      
 151 
     | 
    
         
            +
                    self.data.copy_(loaded_weight)
         
     | 
| 
      
 152 
     | 
    
         
            +
             
     | 
| 
      
 153 
     | 
    
         
            +
             
     | 
| 
       96 
154 
     | 
    
         
             
            class LinearMethodBase(QuantizeMethodBase):
         
     | 
| 
       97 
155 
     | 
    
         
             
                """Base class for different (maybe quantized) linear methods."""
         
     | 
| 
       98 
156 
     | 
    
         | 
| 
         @@ -286,6 +344,8 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       286 
344 
     | 
    
         
             
                    quant_config: Optional[QuantizationConfig] = None,
         
     | 
| 
       287 
345 
     | 
    
         
             
                    output_sizes: Optional[List[int]] = None,
         
     | 
| 
       288 
346 
     | 
    
         
             
                    prefix: str = "",
         
     | 
| 
      
 347 
     | 
    
         
            +
                    tp_rank: Optional[int] = None,
         
     | 
| 
      
 348 
     | 
    
         
            +
                    tp_size: Optional[int] = None,
         
     | 
| 
       289 
349 
     | 
    
         
             
                ):
         
     | 
| 
       290 
350 
     | 
    
         
             
                    super().__init__(
         
     | 
| 
       291 
351 
     | 
    
         
             
                        input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         
     | 
| 
         @@ -294,7 +354,11 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       294 
354 
     | 
    
         
             
                    self.gather_output = gather_output
         
     | 
| 
       295 
355 
     | 
    
         | 
| 
       296 
356 
     | 
    
         
             
                    # Divide the weight matrix along the last dimension.
         
     | 
| 
       297 
     | 
    
         
            -
                     
     | 
| 
      
 357 
     | 
    
         
            +
                    if tp_rank is None:
         
     | 
| 
      
 358 
     | 
    
         
            +
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
      
 359 
     | 
    
         
            +
                    if tp_size is None:
         
     | 
| 
      
 360 
     | 
    
         
            +
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
      
 361 
     | 
    
         
            +
                    self.tp_rank, self.tp_size = tp_rank, tp_size
         
     | 
| 
       298 
362 
     | 
    
         
             
                    assert self.quant_method is not None
         
     | 
| 
       299 
363 
     | 
    
         
             
                    self.output_size_per_partition = divide(self.output_size, tp_size)
         
     | 
| 
       300 
364 
     | 
    
         
             
                    self.output_partition_sizes = [self.output_size_per_partition]
         
     | 
| 
         @@ -335,7 +399,6 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       335 
399 
     | 
    
         
             
                        self.register_parameter("bias", None)
         
     | 
| 
       336 
400 
     | 
    
         | 
| 
       337 
401 
     | 
    
         
             
                def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         
     | 
| 
       338 
     | 
    
         
            -
                    tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       339 
402 
     | 
    
         
             
                    output_dim = getattr(param, "output_dim", None)
         
     | 
| 
       340 
403 
     | 
    
         | 
| 
       341 
404 
     | 
    
         
             
                    # Special case for GGUF
         
     | 
| 
         @@ -355,7 +418,7 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       355 
418 
     | 
    
         
             
                    # no need to narrow here
         
     | 
| 
       356 
419 
     | 
    
         
             
                    if output_dim is not None and not use_bitsandbytes_4bit:
         
     | 
| 
       357 
420 
     | 
    
         
             
                        shard_size = param_data.shape[output_dim]
         
     | 
| 
       358 
     | 
    
         
            -
                        start_idx = tp_rank * shard_size
         
     | 
| 
      
 421 
     | 
    
         
            +
                        start_idx = self.tp_rank * shard_size
         
     | 
| 
       359 
422 
     | 
    
         
             
                        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         
     | 
| 
       360 
423 
     | 
    
         | 
| 
       361 
424 
     | 
    
         
             
                    # Special case for loading scales off disk, which often do not
         
     | 
| 
         @@ -363,7 +426,9 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       363 
426 
     | 
    
         
             
                    if len(loaded_weight.shape) == 0:
         
     | 
| 
       364 
427 
     | 
    
         
             
                        loaded_weight = loaded_weight.reshape(1)
         
     | 
| 
       365 
428 
     | 
    
         | 
| 
       366 
     | 
    
         
            -
                    assert  
     | 
| 
      
 429 
     | 
    
         
            +
                    assert (
         
     | 
| 
      
 430 
     | 
    
         
            +
                        param_data.shape == loaded_weight.shape
         
     | 
| 
      
 431 
     | 
    
         
            +
                    ), f"{param_data.shape=}, {loaded_weight.shape=}"
         
     | 
| 
       367 
432 
     | 
    
         
             
                    param_data.copy_(loaded_weight)
         
     | 
| 
       368 
433 
     | 
    
         | 
| 
       369 
434 
     | 
    
         
             
                def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
         
     | 
| 
         @@ -392,7 +457,7 @@ class ColumnParallelLinear(LinearBase): 
     | 
|
| 
       392 
457 
     | 
    
         
             
                    s = f"in_features={self.input_size}"
         
     | 
| 
       393 
458 
     | 
    
         
             
                    s += f", output_features={self.output_size_per_partition}"
         
     | 
| 
       394 
459 
     | 
    
         
             
                    s += f", bias={self.bias is not None}"
         
     | 
| 
       395 
     | 
    
         
            -
                    s += f", tp_size={ 
     | 
| 
      
 460 
     | 
    
         
            +
                    s += f", tp_size={self.tp_size}"
         
     | 
| 
       396 
461 
     | 
    
         
             
                    s += f", gather_output={self.gather_output}"
         
     | 
| 
       397 
462 
     | 
    
         
             
                    return s
         
     | 
| 
       398 
463 
     | 
    
         | 
| 
         @@ -430,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       430 
495 
     | 
    
         
             
                    params_dtype: Optional[torch.dtype] = None,
         
     | 
| 
       431 
496 
     | 
    
         
             
                    quant_config: Optional[QuantizationConfig] = None,
         
     | 
| 
       432 
497 
     | 
    
         
             
                    prefix: str = "",
         
     | 
| 
      
 498 
     | 
    
         
            +
                    tp_rank: Optional[int] = None,
         
     | 
| 
      
 499 
     | 
    
         
            +
                    tp_size: Optional[int] = None,
         
     | 
| 
      
 500 
     | 
    
         
            +
                    use_presharded_weights: bool = False,
         
     | 
| 
       433 
501 
     | 
    
         
             
                ):
         
     | 
| 
       434 
502 
     | 
    
         
             
                    self.output_sizes = output_sizes
         
     | 
| 
       435 
     | 
    
         
            -
                     
     | 
| 
      
 503 
     | 
    
         
            +
                    if tp_rank is None:
         
     | 
| 
      
 504 
     | 
    
         
            +
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
      
 505 
     | 
    
         
            +
                    if tp_size is None:
         
     | 
| 
      
 506 
     | 
    
         
            +
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
      
 507 
     | 
    
         
            +
                    self.tp_rank, self.tp_size = tp_rank, tp_size
         
     | 
| 
       436 
508 
     | 
    
         
             
                    assert all(output_size % tp_size == 0 for output_size in output_sizes)
         
     | 
| 
      
 509 
     | 
    
         
            +
                    self.use_presharded_weights = use_presharded_weights
         
     | 
| 
       437 
510 
     | 
    
         
             
                    super().__init__(
         
     | 
| 
       438 
511 
     | 
    
         
             
                        input_size=input_size,
         
     | 
| 
       439 
512 
     | 
    
         
             
                        output_size=sum(output_sizes),
         
     | 
| 
         @@ -443,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       443 
516 
     | 
    
         
             
                        params_dtype=params_dtype,
         
     | 
| 
       444 
517 
     | 
    
         
             
                        quant_config=quant_config,
         
     | 
| 
       445 
518 
     | 
    
         
             
                        prefix=prefix,
         
     | 
| 
      
 519 
     | 
    
         
            +
                        tp_rank=tp_rank,
         
     | 
| 
      
 520 
     | 
    
         
            +
                        tp_size=tp_size,
         
     | 
| 
       446 
521 
     | 
    
         
             
                    )
         
     | 
| 
       447 
522 
     | 
    
         | 
| 
       448 
523 
     | 
    
         
             
                def weight_loader(
         
     | 
| 
         @@ -462,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       462 
537 
     | 
    
         
             
                        return
         
     | 
| 
       463 
538 
     | 
    
         | 
| 
       464 
539 
     | 
    
         
             
                    if is_gguf_weight:
         
     | 
| 
       465 
     | 
    
         
            -
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
       466 
     | 
    
         
            -
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       467 
     | 
    
         
            -
             
     | 
| 
       468 
540 
     | 
    
         
             
                        output_dim = getattr(param, "output_dim", None)
         
     | 
| 
       469 
     | 
    
         
            -
                        shard_size = loaded_weight.size(output_dim) // tp_size
         
     | 
| 
       470 
     | 
    
         
            -
                        start_idx = tp_rank * shard_size
         
     | 
| 
      
 541 
     | 
    
         
            +
                        shard_size = loaded_weight.size(output_dim) // self.tp_size
         
     | 
| 
      
 542 
     | 
    
         
            +
                        start_idx = self.tp_rank * shard_size
         
     | 
| 
       471 
543 
     | 
    
         | 
| 
       472 
544 
     | 
    
         
             
                        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         
     | 
| 
       473 
545 
     | 
    
         | 
| 
         @@ -493,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       493 
565 
     | 
    
         
             
                                    param_data, loaded_weight, 0
         
     | 
| 
       494 
566 
     | 
    
         
             
                                )
         
     | 
| 
       495 
567 
     | 
    
         | 
| 
       496 
     | 
    
         
            -
                            assert  
     | 
| 
      
 568 
     | 
    
         
            +
                            assert (
         
     | 
| 
      
 569 
     | 
    
         
            +
                                param_data.shape == loaded_weight.shape
         
     | 
| 
      
 570 
     | 
    
         
            +
                            ), f"{param_data.shape=}, {loaded_weight.shape=}"
         
     | 
| 
       497 
571 
     | 
    
         
             
                            param_data.copy_(loaded_weight)
         
     | 
| 
       498 
572 
     | 
    
         
             
                            return
         
     | 
| 
       499 
573 
     | 
    
         
             
                        current_shard_offset = 0
         
     | 
| 
         @@ -521,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       521 
595 
     | 
    
         
             
                        return
         
     | 
| 
       522 
596 
     | 
    
         | 
| 
       523 
597 
     | 
    
         
             
                    assert loaded_shard_id < len(self.output_sizes)
         
     | 
| 
       524 
     | 
    
         
            -
                    tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       525 
     | 
    
         
            -
                    tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
       526 
598 
     | 
    
         
             
                    if output_dim is not None:
         
     | 
| 
       527 
     | 
    
         
            -
                        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
         
     | 
| 
       528 
     | 
    
         
            -
                        shard_size = self.output_sizes[loaded_shard_id] // tp_size
         
     | 
| 
      
 599 
     | 
    
         
            +
                        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         
     | 
| 
      
 600 
     | 
    
         
            +
                        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         
     | 
| 
       529 
601 
     | 
    
         
             
                        # Special case for quantization.
         
     | 
| 
       530 
602 
     | 
    
         
             
                        # If quantized, we need to adjust the offset and size to account
         
     | 
| 
       531 
603 
     | 
    
         
             
                        # for the packing.
         
     | 
| 
         @@ -544,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       544 
616 
     | 
    
         
             
                            shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
         
     | 
| 
       545 
617 
     | 
    
         | 
| 
       546 
618 
     | 
    
         
             
                        param_data = param_data.narrow(output_dim, shard_offset, shard_size)
         
     | 
| 
       547 
     | 
    
         
            -
                        start_idx = tp_rank * shard_size
         
     | 
| 
      
 619 
     | 
    
         
            +
                        start_idx = self.tp_rank * shard_size
         
     | 
| 
       548 
620 
     | 
    
         
             
                        # bitsandbytes loads the weights of the specific portion
         
     | 
| 
       549 
621 
     | 
    
         
             
                        # no need to narrow here
         
     | 
| 
       550 
     | 
    
         
            -
                        if not use_bitsandbytes_4bit:
         
     | 
| 
      
 622 
     | 
    
         
            +
                        if not use_bitsandbytes_4bit and not self.use_presharded_weights:
         
     | 
| 
       551 
623 
     | 
    
         
             
                            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         
     | 
| 
       552 
624 
     | 
    
         
             
                    # Special case for AQLM codebooks.
         
     | 
| 
       553 
625 
     | 
    
         
             
                    elif is_metadata:
         
     | 
| 
         @@ -571,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       571 
643 
     | 
    
         
             
                                "the same for all partitions."
         
     | 
| 
       572 
644 
     | 
    
         
             
                            )
         
     | 
| 
       573 
645 
     | 
    
         | 
| 
       574 
     | 
    
         
            -
                    assert  
     | 
| 
      
 646 
     | 
    
         
            +
                    assert (
         
     | 
| 
      
 647 
     | 
    
         
            +
                        param_data.shape == loaded_weight.shape
         
     | 
| 
      
 648 
     | 
    
         
            +
                    ), f"{param_data.shape=}, {loaded_weight.shape=}"
         
     | 
| 
       575 
649 
     | 
    
         
             
                    param_data.copy_(loaded_weight)
         
     | 
| 
       576 
650 
     | 
    
         | 
| 
       577 
651 
     | 
    
         
             
                def _load_fused_module_from_checkpoint(
         
     | 
| 
         @@ -628,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       628 
702 
     | 
    
         | 
| 
       629 
703 
     | 
    
         
             
                    assert loaded_shard_id < len(self.output_sizes)
         
     | 
| 
       630 
704 
     | 
    
         | 
| 
       631 
     | 
    
         
            -
                    tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
       632 
     | 
    
         
            -
             
     | 
| 
       633 
705 
     | 
    
         
             
                    if isinstance(param, BlockQuantScaleParameter):
         
     | 
| 
       634 
706 
     | 
    
         
             
                        weight_block_size = self.quant_method.quant_config.weight_block_size
         
     | 
| 
       635 
707 
     | 
    
         
             
                        block_n, _ = weight_block_size[0], weight_block_size[1]
         
     | 
| 
       636 
708 
     | 
    
         
             
                        shard_offset = (
         
     | 
| 
       637 
709 
     | 
    
         
             
                            (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
         
     | 
| 
       638 
     | 
    
         
            -
                        ) // tp_size
         
     | 
| 
      
 710 
     | 
    
         
            +
                        ) // self.tp_size
         
     | 
| 
       639 
711 
     | 
    
         
             
                        shard_size = (
         
     | 
| 
       640 
     | 
    
         
            -
                            (self.output_sizes[loaded_shard_id] + block_n - 1) 
     | 
| 
      
 712 
     | 
    
         
            +
                            (self.output_sizes[loaded_shard_id] + block_n - 1)
         
     | 
| 
      
 713 
     | 
    
         
            +
                            // block_n
         
     | 
| 
      
 714 
     | 
    
         
            +
                            // self.tp_size
         
     | 
| 
       641 
715 
     | 
    
         
             
                        )
         
     | 
| 
       642 
716 
     | 
    
         
             
                    else:
         
     | 
| 
       643 
     | 
    
         
            -
                        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
         
     | 
| 
       644 
     | 
    
         
            -
                        shard_size = self.output_sizes[loaded_shard_id] // tp_size
         
     | 
| 
      
 717 
     | 
    
         
            +
                        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         
     | 
| 
      
 718 
     | 
    
         
            +
                        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         
     | 
| 
       645 
719 
     | 
    
         | 
| 
       646 
720 
     | 
    
         
             
                    param.load_merged_column_weight(
         
     | 
| 
       647 
721 
     | 
    
         
             
                        loaded_weight=loaded_weight,
         
     | 
| 
       648 
722 
     | 
    
         
             
                        shard_id=loaded_shard_id,
         
     | 
| 
       649 
723 
     | 
    
         
             
                        shard_offset=shard_offset,
         
     | 
| 
       650 
724 
     | 
    
         
             
                        shard_size=shard_size,
         
     | 
| 
      
 725 
     | 
    
         
            +
                        use_presharded_weights=self.use_presharded_weights,
         
     | 
| 
       651 
726 
     | 
    
         
             
                    )
         
     | 
| 
       652 
727 
     | 
    
         | 
| 
       653 
728 
     | 
    
         | 
| 
         @@ -688,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       688 
763 
     | 
    
         
             
                    params_dtype: Optional[torch.dtype] = None,
         
     | 
| 
       689 
764 
     | 
    
         
             
                    quant_config: Optional[QuantizationConfig] = None,
         
     | 
| 
       690 
765 
     | 
    
         
             
                    prefix: str = "",
         
     | 
| 
      
 766 
     | 
    
         
            +
                    tp_rank: Optional[int] = None,
         
     | 
| 
      
 767 
     | 
    
         
            +
                    tp_size: Optional[int] = None,
         
     | 
| 
       691 
768 
     | 
    
         
             
                ):
         
     | 
| 
       692 
769 
     | 
    
         
             
                    self.hidden_size = hidden_size
         
     | 
| 
       693 
770 
     | 
    
         
             
                    self.head_size = head_size
         
     | 
| 
         @@ -696,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       696 
773 
     | 
    
         
             
                        total_num_kv_heads = total_num_heads
         
     | 
| 
       697 
774 
     | 
    
         
             
                    self.total_num_kv_heads = total_num_kv_heads
         
     | 
| 
       698 
775 
     | 
    
         
             
                    # Divide the weight matrix along the last dimension.
         
     | 
| 
       699 
     | 
    
         
            -
                     
     | 
| 
      
 776 
     | 
    
         
            +
                    if tp_rank is None:
         
     | 
| 
      
 777 
     | 
    
         
            +
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
      
 778 
     | 
    
         
            +
                    if tp_size is None:
         
     | 
| 
      
 779 
     | 
    
         
            +
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
      
 780 
     | 
    
         
            +
                    self.tp_rank, self.tp_size = tp_rank, tp_size
         
     | 
| 
       700 
781 
     | 
    
         
             
                    self.num_heads = divide(self.total_num_heads, tp_size)
         
     | 
| 
       701 
782 
     | 
    
         
             
                    if tp_size >= self.total_num_kv_heads:
         
     | 
| 
       702 
783 
     | 
    
         
             
                        self.num_kv_heads = 1
         
     | 
| 
         @@ -723,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       723 
804 
     | 
    
         
             
                        params_dtype=params_dtype,
         
     | 
| 
       724 
805 
     | 
    
         
             
                        quant_config=quant_config,
         
     | 
| 
       725 
806 
     | 
    
         
             
                        prefix=prefix,
         
     | 
| 
      
 807 
     | 
    
         
            +
                        tp_rank=tp_rank,
         
     | 
| 
      
 808 
     | 
    
         
            +
                        tp_size=tp_size,
         
     | 
| 
       726 
809 
     | 
    
         
             
                    )
         
     | 
| 
       727 
810 
     | 
    
         | 
| 
       728 
811 
     | 
    
         
             
                def _get_shard_offset_mapping(self, loaded_shard_id: str):
         
     | 
| 
         @@ -813,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       813 
896 
     | 
    
         
             
                        shard_offset = (shard_offset + block_n - 1) // block_n
         
     | 
| 
       814 
897 
     | 
    
         
             
                        shard_size = (shard_size + block_n - 1) // block_n
         
     | 
| 
       815 
898 
     | 
    
         | 
| 
       816 
     | 
    
         
            -
                    param 
     | 
| 
       817 
     | 
    
         
            -
                         
     | 
| 
       818 
     | 
    
         
            -
             
     | 
| 
       819 
     | 
    
         
            -
             
     | 
| 
       820 
     | 
    
         
            -
             
     | 
| 
       821 
     | 
    
         
            -
             
     | 
| 
       822 
     | 
    
         
            -
             
     | 
| 
      
 899 
     | 
    
         
            +
                    if isinstance(param, _ColumnvLLMParameter):
         
     | 
| 
      
 900 
     | 
    
         
            +
                        load_column_qkv_weight(
         
     | 
| 
      
 901 
     | 
    
         
            +
                            param,
         
     | 
| 
      
 902 
     | 
    
         
            +
                            loaded_weight,
         
     | 
| 
      
 903 
     | 
    
         
            +
                            num_heads=self.num_kv_head_replicas,
         
     | 
| 
      
 904 
     | 
    
         
            +
                            shard_id=loaded_shard_id,
         
     | 
| 
      
 905 
     | 
    
         
            +
                            shard_offset=shard_offset,
         
     | 
| 
      
 906 
     | 
    
         
            +
                            shard_size=shard_size,
         
     | 
| 
      
 907 
     | 
    
         
            +
                            tp_rank=self.tp_rank,
         
     | 
| 
      
 908 
     | 
    
         
            +
                        )
         
     | 
| 
      
 909 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 910 
     | 
    
         
            +
                        param.load_qkv_weight(
         
     | 
| 
      
 911 
     | 
    
         
            +
                            loaded_weight=loaded_weight,
         
     | 
| 
      
 912 
     | 
    
         
            +
                            num_heads=self.num_kv_head_replicas,
         
     | 
| 
      
 913 
     | 
    
         
            +
                            shard_id=loaded_shard_id,
         
     | 
| 
      
 914 
     | 
    
         
            +
                            shard_offset=shard_offset,
         
     | 
| 
      
 915 
     | 
    
         
            +
                            shard_size=shard_size,
         
     | 
| 
      
 916 
     | 
    
         
            +
                        )
         
     | 
| 
       823 
917 
     | 
    
         | 
| 
       824 
918 
     | 
    
         
             
                def weight_loader(
         
     | 
| 
       825 
919 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -839,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       839 
933 
     | 
    
         
             
                        return
         
     | 
| 
       840 
934 
     | 
    
         | 
| 
       841 
935 
     | 
    
         
             
                    if is_gguf_weight:
         
     | 
| 
       842 
     | 
    
         
            -
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
       843 
     | 
    
         
            -
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       844 
     | 
    
         
            -
             
     | 
| 
       845 
936 
     | 
    
         
             
                        output_dim = getattr(param, "output_dim", None)
         
     | 
| 
       846 
     | 
    
         
            -
                        shard_size = loaded_weight.size(output_dim) // tp_size
         
     | 
| 
       847 
     | 
    
         
            -
                        start_idx = tp_rank * shard_size
         
     | 
| 
      
 937 
     | 
    
         
            +
                        shard_size = loaded_weight.size(output_dim) // self.tp_size
         
     | 
| 
      
 938 
     | 
    
         
            +
                        start_idx = self.tp_rank * shard_size
         
     | 
| 
       848 
939 
     | 
    
         | 
| 
       849 
940 
     | 
    
         
             
                        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         
     | 
| 
       850 
941 
     | 
    
         | 
| 
         @@ -871,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       871 
962 
     | 
    
         
             
                                    param_data, loaded_weight, 0
         
     | 
| 
       872 
963 
     | 
    
         
             
                                )
         
     | 
| 
       873 
964 
     | 
    
         | 
| 
       874 
     | 
    
         
            -
                            assert  
     | 
| 
      
 965 
     | 
    
         
            +
                            assert (
         
     | 
| 
      
 966 
     | 
    
         
            +
                                param_data.shape == loaded_weight.shape
         
     | 
| 
      
 967 
     | 
    
         
            +
                            ), f"{param_data.shape=}, {loaded_weight.shape=}"
         
     | 
| 
       875 
968 
     | 
    
         
             
                            param_data.copy_(loaded_weight)
         
     | 
| 
       876 
969 
     | 
    
         
             
                            return
         
     | 
| 
       877 
970 
     | 
    
         
             
                        shard_offsets = [
         
     | 
| 
         @@ -933,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       933 
1026 
     | 
    
         
             
                            self.weight_loader(param, loaded_weight_shard, shard_id)
         
     | 
| 
       934 
1027 
     | 
    
         
             
                        return
         
     | 
| 
       935 
1028 
     | 
    
         | 
| 
       936 
     | 
    
         
            -
                    tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       937 
1029 
     | 
    
         
             
                    assert loaded_shard_id in ["q", "k", "v"]
         
     | 
| 
       938 
1030 
     | 
    
         | 
| 
       939 
1031 
     | 
    
         
             
                    # If output dim is defined, use the default loading process.
         
     | 
| 
         @@ -983,9 +1075,9 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       983 
1075 
     | 
    
         | 
| 
       984 
1076 
     | 
    
         
             
                        param_data = param_data.narrow(output_dim, shard_offset, shard_size)
         
     | 
| 
       985 
1077 
     | 
    
         
             
                        if loaded_shard_id == "q":
         
     | 
| 
       986 
     | 
    
         
            -
                            shard_id = tp_rank
         
     | 
| 
      
 1078 
     | 
    
         
            +
                            shard_id = self.tp_rank
         
     | 
| 
       987 
1079 
     | 
    
         
             
                        else:
         
     | 
| 
       988 
     | 
    
         
            -
                            shard_id = tp_rank // self.num_kv_head_replicas
         
     | 
| 
      
 1080 
     | 
    
         
            +
                            shard_id = self.tp_rank // self.num_kv_head_replicas
         
     | 
| 
       989 
1081 
     | 
    
         
             
                        start_idx = shard_id * shard_size
         
     | 
| 
       990 
1082 
     | 
    
         | 
| 
       991 
1083 
     | 
    
         
             
                        # bitsandbytes loads the weights of the specific portion
         
     | 
| 
         @@ -1013,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear): 
     | 
|
| 
       1013 
1105 
     | 
    
         
             
                                "for all partitions."
         
     | 
| 
       1014 
1106 
     | 
    
         
             
                            )
         
     | 
| 
       1015 
1107 
     | 
    
         | 
| 
       1016 
     | 
    
         
            -
                    assert  
     | 
| 
      
 1108 
     | 
    
         
            +
                    assert (
         
     | 
| 
      
 1109 
     | 
    
         
            +
                        param_data.shape == loaded_weight.shape
         
     | 
| 
      
 1110 
     | 
    
         
            +
                    ), f"{param_data.shape=}, {loaded_weight.shape=}"
         
     | 
| 
       1017 
1111 
     | 
    
         
             
                    param_data.copy_(loaded_weight)
         
     | 
| 
       1018 
1112 
     | 
    
         | 
| 
       1019 
1113 
     | 
    
         | 
| 
         @@ -1054,6 +1148,9 @@ class RowParallelLinear(LinearBase): 
     | 
|
| 
       1054 
1148 
     | 
    
         
             
                    reduce_results: bool = True,
         
     | 
| 
       1055 
1149 
     | 
    
         
             
                    quant_config: Optional[QuantizationConfig] = None,
         
     | 
| 
       1056 
1150 
     | 
    
         
             
                    prefix: str = "",
         
     | 
| 
      
 1151 
     | 
    
         
            +
                    tp_rank: Optional[int] = None,
         
     | 
| 
      
 1152 
     | 
    
         
            +
                    tp_size: Optional[int] = None,
         
     | 
| 
      
 1153 
     | 
    
         
            +
                    use_presharded_weights: bool = False,
         
     | 
| 
       1057 
1154 
     | 
    
         
             
                ):
         
     | 
| 
       1058 
1155 
     | 
    
         
             
                    super().__init__(
         
     | 
| 
       1059 
1156 
     | 
    
         
             
                        input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         
     | 
| 
         @@ -1063,10 +1160,14 @@ class RowParallelLinear(LinearBase): 
     | 
|
| 
       1063 
1160 
     | 
    
         
             
                    self.reduce_results = reduce_results
         
     | 
| 
       1064 
1161 
     | 
    
         | 
| 
       1065 
1162 
     | 
    
         
             
                    # Divide the weight matrix along the last dimension.
         
     | 
| 
       1066 
     | 
    
         
            -
                     
     | 
| 
       1067 
     | 
    
         
            -
             
     | 
| 
      
 1163 
     | 
    
         
            +
                    if tp_rank is None:
         
     | 
| 
      
 1164 
     | 
    
         
            +
                        tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
      
 1165 
     | 
    
         
            +
                    if tp_size is None:
         
     | 
| 
      
 1166 
     | 
    
         
            +
                        tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
      
 1167 
     | 
    
         
            +
                    self.tp_rank, self.tp_size = tp_rank, tp_size
         
     | 
| 
       1068 
1168 
     | 
    
         
             
                    self.input_size_per_partition = divide(input_size, self.tp_size)
         
     | 
| 
       1069 
1169 
     | 
    
         
             
                    assert self.quant_method is not None
         
     | 
| 
      
 1170 
     | 
    
         
            +
                    self.use_presharded_weights = use_presharded_weights
         
     | 
| 
       1070 
1171 
     | 
    
         | 
| 
       1071 
1172 
     | 
    
         
             
                    self.quant_method.create_weights(
         
     | 
| 
       1072 
1173 
     | 
    
         
             
                        layer=self,
         
     | 
| 
         @@ -1100,8 +1201,6 @@ class RowParallelLinear(LinearBase): 
     | 
|
| 
       1100 
1201 
     | 
    
         
             
                        self.register_parameter("bias", None)
         
     | 
| 
       1101 
1202 
     | 
    
         | 
| 
       1102 
1203 
     | 
    
         
             
                def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         
     | 
| 
       1103 
     | 
    
         
            -
                    tp_rank = get_tensor_model_parallel_rank()
         
     | 
| 
       1104 
     | 
    
         
            -
                    tp_size = get_tensor_model_parallel_world_size()
         
     | 
| 
       1105 
1204 
     | 
    
         
             
                    input_dim = getattr(param, "input_dim", None)
         
     | 
| 
       1106 
1205 
     | 
    
         
             
                    use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         
     | 
| 
       1107 
1206 
     | 
    
         | 
| 
         @@ -1115,15 +1214,19 @@ class RowParallelLinear(LinearBase): 
     | 
|
| 
       1115 
1214 
     | 
    
         
             
                    if is_gguf_weight and isinstance(param, UninitializedParameter):
         
     | 
| 
       1116 
1215 
     | 
    
         
             
                        weight_shape = list(loaded_weight.shape)
         
     | 
| 
       1117 
1216 
     | 
    
         
             
                        if input_dim:
         
     | 
| 
       1118 
     | 
    
         
            -
                            weight_shape[input_dim] = weight_shape[input_dim] // tp_size
         
     | 
| 
      
 1217 
     | 
    
         
            +
                            weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
         
     | 
| 
       1119 
1218 
     | 
    
         
             
                        param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
         
     | 
| 
       1120 
1219 
     | 
    
         | 
| 
       1121 
1220 
     | 
    
         
             
                    param_data = param.data
         
     | 
| 
       1122 
1221 
     | 
    
         
             
                    # bitsandbytes loads the weights of the specific portion
         
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -1131,7 +1234,9 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1148,11 +1253,10 @@ class RowParallelLinear(LinearBase):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         assert self.quant_method is not None
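The two RowParallelLinear changes above are related: `weight_loader` now skips the per-rank `narrow` when `self.use_presharded_weights` is set (the checkpoint already stores one shard per rank), and `forward` reuses the cached `self.tp_rank` instead of querying the rank on every call. A minimal standalone sketch of the loading logic; the helper name, the toy shapes, and the hard-coded `input_dim` are assumptions for illustration, not sglang's API:

    import torch

    def load_row_parallel_weight(
        param_data: torch.Tensor,
        loaded_weight: torch.Tensor,
        tp_rank: int,
        use_presharded_weights: bool = False,
    ) -> None:
        # Row-parallel linear: the input dimension (dim=1 of the weight) is
        # split across tensor-parallel ranks.
        input_dim = 1
        if not use_presharded_weights:
            # Full checkpoint tensor: cut out this rank's slice before copying.
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
        # Presharded checkpoints already match the local shard shape.
        assert (
            param_data.shape == loaded_weight.shape
        ), f"{param_data.shape=}, {loaded_weight.shape=}"
        param_data.copy_(loaded_weight)

    # Toy check: 2-way TP, full weight (4, 8), each rank holds a (4, 4) shard.
    full = torch.arange(32, dtype=torch.float32).reshape(4, 8)
    shard = torch.empty(4, 4)
    load_row_parallel_weight(shard, full, tp_rank=1)
    assert torch.equal(shard, full[:, 4:])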
sglang/srt/layers/logits_processor.py
CHANGED
@@ -74,11 +74,6 @@ class LogitsMetadata:

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
-
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:

         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
@@ -122,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None

     def forward(
         self,
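Two things change in logits_processor.py: `LogitsMetadata.from_forward_batch` now reads `capture_hidden_mode` directly from the forward batch instead of going through `spec_info`, and a negative `final_logit_softcapping` value is now treated as "disabled". Softcapping is commonly applied as `cap * tanh(logits / cap)` (Gemma-2 style), where a negative cap would flip and distort the logits. The snippet below is a standalone sketch of that behavior under this assumption, not the sglang code:

    from typing import Optional

    import torch

    def apply_final_logit_softcapping(
        logits: torch.Tensor, cap: Optional[float]
    ) -> torch.Tensor:
        # Mirror the new guard: a missing or negative cap means "no softcapping".
        if cap is not None and cap < 0:
            cap = None
        if cap is None:
            return logits
        # Smoothly squash logits into (-cap, cap) while preserving their ordering.
        return cap * torch.tanh(logits / cap)

    logits = torch.tensor([-100.0, 0.0, 100.0])
    print(apply_final_logit_softcapping(logits, 30.0))   # roughly [-29.9, 0.0, 29.9]
    print(apply_final_logit_softcapping(logits, -1.0))   # unchanged: capping disabled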
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -1011,11 +1011,22 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            torch.sum(
-                intermediate_cache3.view(*intermediate_cache3.shape),
-                dim=1,
-                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-            )
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )

     return out_hidden_states

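The `fused_experts_impl` change specializes the final reduction over the top-k expert outputs: a plain copy when top_k == 1, a single `torch.add` when top_k == 2, and the generic `torch.sum` only when top_k > 2, which avoids a full reduction kernel for the common small-k cases. A small self-contained check that the three paths agree; tensor names and shapes are illustrative only:

    import torch

    def combine_topk_outputs(intermediate_cache3: torch.Tensor, out: torch.Tensor) -> None:
        # intermediate_cache3: (num_tokens, top_k, hidden); out: (num_tokens, hidden)
        top_k = intermediate_cache3.shape[1]
        if top_k == 1:
            out.copy_(intermediate_cache3[:, 0])
        elif top_k == 2:
            torch.add(intermediate_cache3[:, 0], intermediate_cache3[:, 1], out=out)
        else:
            torch.sum(intermediate_cache3, dim=1, out=out)

    for top_k in (1, 2, 4):
        cache3 = torch.randn(16, top_k, 32)
        out = torch.empty(16, 32)
        combine_topk_outputs(cache3, out)
        assert torch.allclose(out, cache3.sum(dim=1))

Note that the k == 2 branch in the diff chains `.squeeze(dim=1)` onto the `torch.add(..., out=...)` call; since the result is written through `out=`, that trailing squeeze only reshapes the discarded return value and is omitted in the sketch above.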
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +397,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
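In `FusedMoE`, `use_presharded_weights` moves from a per-call `weight_loader` argument (which was also being assigned onto `self` during loading) to a constructor argument stored once at init, so model code that loads presharded checkpoints can set it at construction time. A toy sketch of the resulting pattern; the class, shapes, and parameter names here are illustrative stand-ins, not the sglang API:

    import torch

    class ToyFusedMoE(torch.nn.Module):
        """Illustrative only: the preshard flag is construction-time state,
        so checkpoint-loading code no longer threads it through weight_loader()."""

        def __init__(self, num_experts: int, shard_size: int, tp_rank: int,
                     use_presharded_weights: bool = False):
            super().__init__()
            self.tp_rank = tp_rank
            self.use_presharded_weights = use_presharded_weights
            self.w = torch.nn.Parameter(
                torch.zeros(num_experts, shard_size), requires_grad=False
            )

        def weight_loader(self, loaded_weight: torch.Tensor, expert_id: int) -> None:
            shard_size = self.w.shape[1]
            if not self.use_presharded_weights:
                # Full checkpoint row: narrow to this rank's slice first.
                loaded_weight = loaded_weight.narrow(
                    0, self.tp_rank * shard_size, shard_size
                )
            self.w.data[expert_id].copy_(loaded_weight)

    # A presharded checkpoint already stores only this rank's slice per expert.
    layer = ToyFusedMoE(num_experts=2, shard_size=4, tp_rank=1, use_presharded_weights=True)
    layer.weight_loader(torch.ones(4), expert_id=0)
    assert torch.equal(layer.w[0], torch.ones(4))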