sglang 0.3.1__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +7 -2
- sglang/global_config.py +5 -13
- sglang/lang/interpreter.py +0 -3
- sglang/srt/constrained/fsm_cache.py +5 -1
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +12 -12
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/sampler.py +32 -97
- sglang/srt/lora/lora_manager.py +11 -8
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/tp_worker.py +8 -7
- sglang/srt/model_executor/cuda_graph_runner.py +12 -1
- sglang/srt/model_executor/model_runner.py +24 -41
- sglang/srt/models/deepseek_v2.py +6 -1
- sglang/srt/models/minicpm3.py +5 -1
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/sampling/sampling_batch_info.py +3 -50
- sglang/srt/server.py +6 -1
- sglang/srt/server_args.py +34 -1
- sglang/srt/utils.py +7 -51
- sglang/test/test_utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +2 -2
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/RECORD +28 -27
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
    
sglang/bench_latency.py
CHANGED

@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import suppress_other_loggers
+from sglang.srt.utils import kill_child_process, suppress_other_loggers


 @dataclasses.dataclass
@@ -502,4 +502,9 @@ if __name__ == "__main__":
         format="%(message)s",
     )

-    main(server_args, bench_args)
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
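
The new cleanup path relies on kill_child_process from sglang.srt.utils. A minimal sketch of what such a helper typically does, assuming psutil is available (this is an illustration, not necessarily the sglang implementation):

    import psutil

    def kill_child_process(pid: int, including_parent: bool = True) -> None:
        """Terminate every descendant of `pid`; optionally the parent too."""
        try:
            parent = psutil.Process(pid)
        except psutil.NoSuchProcess:
            return
        for child in parent.children(recursive=True):
            child.kill()
        if including_parent:
            parent.kill()
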
    
sglang/global_config.py
CHANGED

@@ -1,5 +1,7 @@
 """Global configurations"""

+import os
+

 class GlobalConfig:
     def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001

-        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 384 * 1024 * 1024
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )

         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True

         # Interpreter optimization configs
-        self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Deprecated
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"


 global_config = GlobalConfig()
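
With this change the FlashInfer workspace size can be overridden through the FLASHINFER_WORKSPACE_SIZE environment variable. A minimal usage sketch; note that os.environ.get returns a string when the variable is set, so the override is passed through as-is:

    import os

    # Must be set before sglang.global_config is imported.
    os.environ["FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)

    from sglang.global_config import global_config
    print(global_config.flashinfer_workspace_size)
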
    
sglang/lang/interpreter.py
CHANGED

@@ -434,9 +434,6 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

-        # if global_config.eager_fill_image:
-        #     self.backend.fill_image(self)
-
     def _spec_gen(self, sampling_params):
         stop = sampling_params.stop
         max_new_tokens = sampling_params.max_new_tokens

sglang/srt/constrained/fsm_cache.py
CHANGED

@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        constrained_json_whitespace_pattern=None,
     ):
         super().__init__(enable=enable)

@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
            )
+        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern

     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-            regex = build_regex_from_schema(key_string)
+            regex = build_regex_from_schema(
+                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
+            )
         elif key_type == "regex":
             regex = key_string
         else:
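
The new constrained_json_whitespace_pattern is forwarded to outlines' build_regex_from_schema. A small sketch of the effect, assuming outlines is installed (the exact import path may differ across outlines versions):

    from outlines.fsm.json_schema import build_regex_from_schema

    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    # Restrict the whitespace allowed between JSON tokens in generated output.
    regex = build_regex_from_schema(schema, whitespace_pattern=r"[\n\t ]*")
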
    
sglang/srt/layers/activation.py
CHANGED

@@ -13,6 +13,7 @@ limitations under the License.

 """Fused operators for activation layers."""

+import logging
 from typing import Optional

 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
             act_fn, intermediate_size, input_is_parallel, params_dtype
         )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
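
Several of the hunks in this release gate FlashInfer-dependent code behind is_hip() from sglang.srt.utils. A rough sketch of what such a check typically looks like (illustrative only; the real helper lives in sglang/srt/utils.py):

    import torch

    def is_hip() -> bool:
        # True when PyTorch was built against ROCm/HIP (AMD GPUs).
        return torch.version.hip is not None
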
sglang/srt/layers/attention_backend.py
CHANGED

@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+

 class AttentionBackend(ABC):
     """The base class of attention backends"""
@@ -150,7 +154,7 @@ class FlashInferAttnBackend(AttentionBackend):
             # Some heuristics to check whether to use ragged forward
             use_ragged = False
             if (
-
+                torch.sum(input_metadata.seq_lens).item() >= 4096
                 and self.model_runner.sliding_window_size is None
             ):
                 use_ragged = True
@@ -301,10 +305,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 layer.layer_id, input_metadata.out_cache_loc, k, v
             )

-            if total_num_tokens >= global_config.layer_sync_threshold:
-                # TODO: Revisit this. Why is this synchronize needed?
-                torch.cuda.synchronize()
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):

sglang/srt/layers/fused_moe/layer.py
CHANGED

@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)


@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

     def process_weights_after_loading(self, layer: Module) -> None:

-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.a2_scale.max(), requires_grad=False
                 )

+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
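
The ROCm branch switches the FP8 weight format from e4m3fn to e4m3fnuz, the variant supported by MI300-class hardware; vllm's normalize_e4m3fn_to_e4m3fnuz converts the serialized weights and their scales to that representation. A minimal sketch of the dtype selection only (requires a PyTorch build that exposes both FP8 dtypes):

    import torch

    def fp8_dtype(on_hip: bool) -> torch.dtype:
        # MI300-class (ROCm) hardware uses the e4m3fnuz FP8 variant; NVIDIA GPUs use e4m3fn.
        return torch.float8_e4m3fnuz if on_hip else torch.float8_e4m3fn
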
    
sglang/srt/layers/layernorm.py
CHANGED

@@ -15,6 +15,7 @@ limitations under the License.

 """Fused operators for normalization layers."""

+import logging
 from typing import Optional, Tuple, Union

 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm

sglang/srt/layers/sampler.py
CHANGED

@@ -1,51 +1,28 @@
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union

 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )

 logger = logging.getLogger(__name__)


-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -53,7 +30,15 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )

         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +52,16 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
         else:
@@ -80,48 +69,7 @@ class Sampler(CustomOp):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )

-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +89,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+    return batch_next_token_ids
sglang/srt/lora/lora_manager.py
CHANGED

@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass

 import torch
-from flashinfer import SegmentGEMMWrapper

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper


 def get_stacked_name(name):
@@ -96,10 +99,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +117,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
            )
             self.loras[-1].initialize_weights()
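
These hunks assume lora_paths is now a mapping from an adapter name to its checkpoint path rather than a bare list of paths. A tiny sketch of the expected shape (names and paths below are hypothetical):

    lora_paths = {
        "sql_adapter": "/models/loras/sql",
        "chat_adapter": "/models/loras/chat",
    }
    for name, path in lora_paths.items():
        print(f"adapter {name!r} -> {path}")
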
    
sglang/srt/managers/tp_worker.py
CHANGED

@@ -198,6 +198,7 @@ class ModelTpServer:
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
+                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
             )
         self.jump_forward_cache = JumpForwardCache()

@@ -414,7 +415,7 @@ class ModelTpServer:

         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -807,12 +808,10 @@ class ModelTpServer:
                     unfinished_indices.append(i)

                 if req.finished() or (
-                    (
-                        req.stream
-                        and (
-                            self.decode_forward_ct % self.stream_interval == 0
-                            or len(req.output_ids) == 1
-                        )
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
                     )
                 ):
                     output_rids.append(req.rid)
@@ -937,6 +936,8 @@ class ModelTpServer:
             if success:
                 flash_cache_success = self.flush_cache()
                 assert flash_cache_success, "Cache flush failed after updating weights"
+            else:
+                logger.error(message)
             return success, message

sglang/srt/model_executor/cuda_graph_runner.py
CHANGED

@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE torch native implementaiton is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -105,7 +108,15 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs = 
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )

         # Common inputs
         self.max_bs = max(self.capture_bs)