sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/layers/linear.py CHANGED

@@ -20,8 +20,10 @@ from vllm.distributed import (
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
+    PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
+    RowvLLMParameter,
 )
 
 from sglang.srt.layers.quantization.base_config import (
@@ -39,6 +41,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
+    "GPTQLinearMethod",
 ]
 
 
@@ -50,7 +53,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def
+def adjust_bitsandbytes_4bit_shard(
     param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -207,7 +210,6 @@ class ReplicatedLinear(LinearBase):
             self.output_size,
             self.params_dtype,
             weight_loader=self.weight_loader,
-            prefix=prefix,
         )
 
         if bias:
@@ -315,7 +317,6 @@ class ColumnParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if bias:
             self.bias = Parameter(
@@ -345,8 +346,12 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
 
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
         param_data = param.data
-
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -454,17 +459,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight
-
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
 
-
-
-
-
-
-
-
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -526,26 +536,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )
 
-
-                if
+                use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+                if use_bitsandbytes_4bit:
                     shard_size = loaded_weight.shape[output_dim]
                     shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
-                if is_gguf_weight:
-                    tp_size = get_tensor_model_parallel_world_size()
-                    output_dim = getattr(param, "output_dim", None)
-                    shard_shape = list(loaded_weight.shape)
-                    shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                    param.shard_id.append(loaded_shard_id)
-                    param.shard_size[loaded_shard_id] = shard_shape
-
-                    input_dim = getattr(param, "input_dim", None)
-                    input_size = loaded_weight.shape[input_dim]
-                    param_data = param_data.narrow(input_dim, 0, input_size)
-
                 param_data = param_data.narrow(output_dim, shard_offset, shard_size)
                 start_idx = tp_rank * shard_size
-
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit:
+                    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -595,7 +596,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
-                isinstance(param, PackedvLLMParameter)
+                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -617,7 +618,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
-            elif type(param)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -760,7 +761,7 @@ class QKVParallelLinear(ColumnParallelLinear):
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
-                isinstance(param, PackedvLLMParameter)
+                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -780,10 +781,10 @@ class QKVParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param)
-                param.
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
@@ -818,17 +819,22 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight
-
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
 
-
-
-
-
-
-
-
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -863,6 +869,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     self.total_num_kv_heads * self.head_size,
                 ),
             ]
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
@@ -877,6 +885,29 @@ class QKVParallelLinear(ColumnParallelLinear):
                        param, shard_size, shard_offset
                    )
 
+                if use_bitsandbytes_4bit:
+                    orig_qkv_offsets = {
+                        "q": (0, self.total_num_heads * self.head_size),
+                        "k": (
+                            self.total_num_heads * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "v": (
+                            (self.total_num_heads + self.total_num_kv_heads)
+                            * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "total": (
+                            (self.total_num_heads + 2 * self.total_num_kv_heads)
+                            * self.head_size,
+                            0,
+                        ),
+                    }
+
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_qkv_offsets, shard_id
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
@@ -910,8 +941,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                        param, shard_size, shard_offset
                    )
 
-
-                if
+                use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.num_heads * self.head_size),
                        "k": (
@@ -927,29 +958,22 @@ class QKVParallelLinear(ColumnParallelLinear):
                            0,
                        ),
                    }
-                    shard_size, shard_offset =
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, loaded_shard_id
                    )
 
-                if is_gguf_weight:
-                    tp_size = get_tensor_model_parallel_world_size()
-                    output_dim = getattr(param, "output_dim", None)
-                    shard_shape = list(loaded_weight.shape)
-                    shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                    param.shard_id.append(loaded_shard_id)
-                    param.shard_size[loaded_shard_id] = shard_shape
-
-                    input_dim = getattr(param, "input_dim", None)
-                    input_size = loaded_weight.shape[input_dim]
-                    param_data = param_data.narrow(input_dim, 0, input_size)
-
                param_data = param_data.narrow(output_dim, shard_offset, shard_size)
                if loaded_shard_id == "q":
                    shard_id = tp_rank
                else:
                    shard_id = tp_rank // self.num_kv_head_replicas
                start_idx = shard_id * shard_size
-
+
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit:
+                    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
        # Special case for for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
@@ -1037,7 +1061,6 @@ class RowParallelLinear(LinearBase):
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
-            prefix=prefix,
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
@@ -1061,6 +1084,7 @@ class RowParallelLinear(LinearBase):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1076,7 +1100,9 @@ class RowParallelLinear(LinearBase):
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
        param_data = param.data
-
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
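The recurring pattern in these linear.py hunks is that a parameter flagged `use_bitsandbytes_4bit` is assumed to already hold only the current rank's portion of the weight, so the tensor-parallel `narrow()` step is skipped. Below is a minimal standalone sketch of that rule; the `FakeParam` class, `load_column_shard` helper, and tensor shapes are made up for illustration and are not sglang's API.

    import torch


    class FakeParam:
        def __init__(self, data: torch.Tensor, use_bitsandbytes_4bit: bool = False):
            self.data = data
            self.use_bitsandbytes_4bit = use_bitsandbytes_4bit


    def load_column_shard(param: FakeParam, loaded_weight: torch.Tensor,
                          tp_rank: int, output_dim: int = 0) -> torch.Tensor:
        use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if output_dim is not None and not use_bnb_4bit:
            # Ordinary weight: cut this rank's slice out of the full checkpoint tensor.
            shard_size = param.data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # bitsandbytes 4-bit weights already are the per-rank portion; copy as-is.
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)
        return param.data


    # Two ranks, output dim split 8 -> 4 rows per rank.
    full = torch.arange(16.0).reshape(8, 2)
    plain = FakeParam(torch.empty(4, 2))
    load_column_shard(plain, full, tp_rank=1)                    # narrows rows 4..7
    bnb = FakeParam(torch.empty(4, 2), use_bitsandbytes_4bit=True)
    load_column_shard(bnb, full[4:8], tp_rank=1)                 # already sharded, no narrow
    print(plain.data.equal(bnb.data))                            # True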
sglang/srt/layers/logits_processor.py CHANGED

@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens.       shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens.     shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None
 
     # The normlaized logprobs of prompts.  shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens.        shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None
 
     # The logprob and id of the top-k tokens in input positions.  shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None
 
 
 @dataclasses.dataclass
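All five logprob-related fields gain a `None` default, which means the dataclass can now be constructed from the logits alone and the logprob fields filled in later. A small stand-in sketch of the effect; `DemoOutput` is illustrative, not the sglang class.

    import dataclasses
    from typing import List, Optional

    import torch


    @dataclasses.dataclass
    class DemoOutput:  # stand-in for LogitsProcessorOutput, not the real class
        next_token_logits: torch.Tensor
        next_token_logprobs: Optional[torch.Tensor] = None
        input_top_logprobs: Optional[List] = None


    # With defaults, a partial construction like this is valid:
    out = DemoOutput(next_token_logits=torch.zeros(2, 8))
    print(out.next_token_logprobs is None)  # True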
sglang/srt/layers/rotary_embedding.py ADDED

@@ -0,0 +1,112 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""MRotaryEmbedding"""
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+
+class MRotaryEmbedding:
+    """Rotary Embedding with Multimodal Sections."""
+
+    @staticmethod
+    def get_input_positions(
+        input_tokens: torch.Tensor,
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        vision_start_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+    ) -> Tuple[List[List[int]], int]:
+        """Get mrope input positions and delta value."""
+
+        if isinstance(image_grid_thw, torch.Tensor):
+            image_grid_thw = image_grid_thw.tolist()
+
+        vision_start_indices = torch.argwhere(
+            input_tokens == vision_start_token_id
+        ).squeeze(1)
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
+        llm_pos_ids_list: list = []
+
+        st = 0
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                torch.arange(llm_grid_t)
+                .view(-1, 1)
+                .expand(-1, llm_grid_h * llm_grid_w)
+                .flatten()
+            )
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < input_tokens_len:
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = input_tokens_len - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:]
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_next_input_positions(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> List[List[int]]:
+        return [
+            list(
+                range(
+                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+                )
+            )
+            for _ in range(3)
+        ]
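A usage sketch for the new helper, assuming this release of sglang is installed so the module above can be imported. The token ids and image grid are made-up toy values, not real model inputs; the vision-start id in particular is a placeholder.

    import torch

    from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

    VISION_START = 1000  # placeholder id for the vision-start token (assumption)
    # 2 text tokens, the vision-start marker, 4 image tokens (1x4x4 grid merged 2x2), 2 text tokens.
    input_tokens = torch.tensor([11, 12, VISION_START, 0, 0, 0, 0, 13, 14])

    positions, delta = MRotaryEmbedding.get_input_positions(
        input_tokens=input_tokens,
        image_grid_thw=[[1, 4, 4]],      # one image: t=1, h=4, w=4 patches
        vision_start_token_id=VISION_START,
        spatial_merge_size=2,            # 4x4 patches -> 2x2 = 4 image tokens
    )
    print(len(positions), len(positions[0]))  # 3 axes (t/h/w), one position per input token
    # Decode-time positions for the next token are derived from the stored delta:
    print(MRotaryEmbedding.get_next_input_positions(delta, context_len=9, seq_len=10))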
    
sglang/srt/layers/sampler.py CHANGED

@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union
 
 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )
 
+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        # Post process logits
         logits = logits.contiguous()
-
-
-
-
-
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(
-
-
-
-
-
-
                         
     | 
| 
       58 
     | 
    
         
            -
             
     | 
| 
       59 
     | 
    
         
            -
             
     | 
| 
       60 
     | 
    
         
            -
                             
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
      
 53 
     | 
    
         
            +
                        batch_next_token_ids = torch.argmax(logits, -1)
         
     | 
| 
      
 54 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 55 
     | 
    
         
            +
                        # Post process logits
         
     | 
| 
      
 56 
     | 
    
         
            +
                        logits.div_(sampling_info.temperatures)
         
     | 
| 
      
 57 
     | 
    
         
            +
                        probs = torch.softmax(logits, dim=-1)
         
     | 
| 
      
 58 
     | 
    
         
            +
                        logits = None
         
     | 
| 
      
 59 
     | 
    
         
            +
                        del logits
         
     | 
| 
      
 60 
     | 
    
         
            +
             
     | 
| 
      
 61 
     | 
    
         
            +
                        if global_server_args_dict["sampling_backend"] == "flashinfer":
         
     | 
| 
      
 62 
     | 
    
         
            +
                            max_top_k_round, batch_size = 32, probs.shape[0]
         
     | 
| 
      
 63 
     | 
    
         
            +
                            uniform_samples = torch.rand(
         
     | 
| 
      
 64 
     | 
    
         
            +
                                (max_top_k_round, batch_size), device=probs.device
         
     | 
| 
       62 
65 
     | 
    
         
             
                            )
         
     | 
| 
       63 
     | 
    
         
            -
             
     | 
| 
       64 
     | 
    
         
            -
             
     | 
| 
      
 66 
     | 
    
         
            +
                            if sampling_info.need_min_p_sampling:
         
     | 
| 
      
 67 
     | 
    
         
            +
                                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
         
     | 
| 
      
 68 
     | 
    
         
            +
                                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
         
     | 
| 
      
 69 
     | 
    
         
            +
                                batch_next_token_ids, success = min_p_sampling_from_probs(
         
     | 
| 
      
 70 
     | 
    
         
            +
                                    probs, uniform_samples, sampling_info.min_ps
         
     | 
| 
      
 71 
     | 
    
         
            +
                                )
         
     | 
| 
      
 72 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 73 
     | 
    
         
            +
                                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
         
     | 
| 
      
 74 
     | 
    
         
            +
                                    probs,
         
     | 
| 
      
 75 
     | 
    
         
            +
                                    uniform_samples,
         
     | 
| 
      
 76 
     | 
    
         
            +
                                    sampling_info.top_ks,
         
     | 
| 
      
 77 
     | 
    
         
            +
                                    sampling_info.top_ps,
         
     | 
| 
      
 78 
     | 
    
         
            +
                                    filter_apply_order="joint",
         
     | 
| 
      
 79 
     | 
    
         
            +
                                )
         
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
                            if not torch.all(success):
         
     | 
| 
      
 82 
     | 
    
         
            +
                                logger.warning("Detected errors during sampling!")
         
     | 
| 
      
 83 
     | 
    
         
            +
                                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         
     | 
| 
      
 84 
     | 
    
         
            +
                        elif global_server_args_dict["sampling_backend"] == "pytorch":
         
     | 
| 
      
 85 
     | 
    
         
            +
                            # A slower fallback implementation with torch native operations.
         
     | 
| 
      
 86 
     | 
    
         
            +
                            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
         
     | 
| 
       65 
87 
     | 
    
         
             
                                probs,
         
     | 
| 
       66 
     | 
    
         
            -
                                uniform_samples,
         
     | 
| 
       67 
88 
     | 
    
         
             
                                sampling_info.top_ks,
         
     | 
| 
       68 
89 
     | 
    
         
             
                                sampling_info.top_ps,
         
     | 
| 
       69 
     | 
    
         
            -
                                 
     | 
| 
      
 90 
     | 
    
         
            +
                                sampling_info.min_ps,
         
     | 
| 
      
 91 
     | 
    
         
            +
                            )
         
     | 
| 
      
 92 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 93 
     | 
    
         
            +
                            raise ValueError(
         
     | 
| 
      
 94 
     | 
    
         
            +
                                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
         
     | 
| 
       70 
95 
     | 
    
         
             
                            )
         
     | 
| 
       71 
96 
     | 
    
         | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
                            logger.warning("Detected errors during sampling!")
         
     | 
| 
       74 
     | 
    
         
            -
                            batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         
     | 
| 
       75 
     | 
    
         
            -
                    elif global_server_args_dict["sampling_backend"] == "pytorch":
         
     | 
| 
       76 
     | 
    
         
            -
                        # Here we provide a slower fallback implementation.
         
     | 
| 
       77 
     | 
    
         
            -
                        batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
         
     | 
| 
       78 
     | 
    
         
            -
                            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
         
     | 
| 
       79 
     | 
    
         
            -
                        )
         
     | 
| 
       80 
     | 
    
         
            -
                    else:
         
     | 
| 
       81 
     | 
    
         
            -
                        raise ValueError(
         
     | 
| 
       82 
     | 
    
         
            -
                            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
         
     | 
| 
       83 
     | 
    
         
            -
                        )
         
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
                    return batch_next_token_ids
         
     | 
| 
      
 97 
     | 
    
         
            +
                    return batch_next_token_ids.to(torch.int32)
         
     | 
| 
       86 
98 
     | 
    
         | 
| 
       87 
99 
     | 
    
         | 
| 
       88 
100 
     | 
    
         
             
            def top_k_top_p_min_p_sampling_from_probs_torch(
         
     | 
    
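The torch-native fallback named above is not shown in this hunk. The sketch below is a generic illustration of joint top-k/top-p/min-p filtering over a probability tensor; the function name and the literal test values are ours, not sglang's, and the real implementation may differ:

    import torch

    def filter_probs_sketch(probs, top_ks, top_ps, min_ps):
        # probs: [batch, vocab]; top_ks/top_ps/min_ps: per-request [batch] tensors.
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        # top-p: drop a token once the mass before it already exceeds top_p
        mask = (cum - sorted_probs) > top_ps.unsqueeze(1)
        # top-k: drop tokens ranked at or beyond top_k
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(sorted_probs)
        mask |= ranks >= top_ks.unsqueeze(1)
        # min-p: drop tokens whose probability is below min_p * (max probability)
        mask |= sorted_probs < min_ps.unsqueeze(1) * sorted_probs[:, :1]
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return sorted_idx.gather(-1, choice).squeeze(-1)

    probs = torch.softmax(torch.randn(2, 32), dim=-1)
    next_ids = filter_probs_sketch(
        probs,
        top_ks=torch.tensor([5, 1]),
        top_ps=torch.tensor([0.9, 1.0]),
        min_ps=torch.tensor([0.0, 0.0]),
    )
    print(next_ids)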
sglang/srt/lora/lora.py CHANGED

@@ -351,7 +351,9 @@ class LoRAAdapter(nn.Module):
         loader = DefaultModelLoader(self.load_config)
         revision = getattr(self.config.hf_config, "revision", None)
         for name, loaded_weight in loader._get_weights_iterator(
-            model_path, revision=revision, fall_back_to_pt=True
+            DefaultModelLoader.Source(
+                model_path, revision=revision, fall_back_to_pt=True
+            )
         ):
             match = re.search(r"layers\.(\d+)\.", name)
             if match is not None:
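The loop this hunk touches keys LoRA weights by the layer index embedded in their names. The snippet below only illustrates the layers\.(\d+)\. match used there; the weight names are hypothetical examples, not taken from any real checkpoint:

    import re

    names = [
        "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
        "base_model.model.layers.12.mlp.down_proj.lora_B.weight",
        "base_model.model.embed_tokens.weight",  # no layer index -> skipped
    ]

    for name in names:
        match = re.search(r"layers\.(\d+)\.", name)
        if match is not None:
            layer_id = int(match.group(1))
            print(f"{name} -> layer {layer_id}")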
sglang/srt/managers/detokenizer_manager.py CHANGED

@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    GetMemPoolSizeReqOutput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
                 # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
             elif self.tokenizer is None:
                 # If the tokenizer is skipped, no detokenization is needed
                 self.send_to_tokenizer.send_pyobj(recv_obj)
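The added branch mirrors the existing pass-through for weight-update replies: control-plane messages are forwarded to the tokenizer manager without detokenization. Below is a minimal sketch of that dispatch order, using hypothetical stand-in classes rather than the real io_struct definitions:

    from dataclasses import dataclass

    # Hypothetical stand-ins for the io_struct message types; only the dispatch order matters.
    @dataclass
    class UpdateWeightReqOutput:
        success: bool

    @dataclass
    class GetMemPoolSizeReqOutput:
        size: int

    @dataclass
    class BatchTokenIDOut:
        output_ids: list

    def handle(recv_obj, send_to_tokenizer, detokenize):
        # Control-plane replies are forwarded untouched ...
        if isinstance(recv_obj, (UpdateWeightReqOutput, GetMemPoolSizeReqOutput)):
            send_to_tokenizer(recv_obj)
            return
        # ... while token batches are detokenized first.
        send_to_tokenizer(detokenize(recv_obj))

    handle(GetMemPoolSizeReqOutput(size=4096), print, lambda batch: batch)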