sglang 0.3.1__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +7 -2
- sglang/global_config.py +5 -13
- sglang/lang/interpreter.py +0 -3
- sglang/srt/constrained/fsm_cache.py +5 -1
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +12 -12
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/sampler.py +32 -97
- sglang/srt/lora/lora_manager.py +11 -8
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/tp_worker.py +8 -7
- sglang/srt/model_executor/cuda_graph_runner.py +12 -1
- sglang/srt/model_executor/model_runner.py +24 -41
- sglang/srt/models/deepseek_v2.py +6 -1
- sglang/srt/models/minicpm3.py +5 -1
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/sampling/sampling_batch_info.py +3 -50
- sglang/srt/server.py +6 -1
- sglang/srt/server_args.py +34 -1
- sglang/srt/utils.py +7 -51
- sglang/test/test_utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +2 -2
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/RECORD +28 -27
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
```diff
@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import suppress_other_loggers
+from sglang.srt.utils import kill_child_process, suppress_other_loggers
 
 
 @dataclasses.dataclass
@@ -502,4 +502,9 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
```
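The new `try`/`finally` wrapper guarantees that subprocesses spawned by the benchmark are reaped even if `main` raises. A minimal sketch of that cleanup pattern on POSIX, assuming `psutil` is available; sglang's real `kill_child_process` helper may be implemented differently:

```python
import multiprocessing as mp
import os
import signal
import time

import psutil  # assumption: psutil is installed


def kill_children(parent_pid: int) -> None:
    # Stand-in for kill_child_process(pid, including_parent=False):
    # terminate every child of `parent_pid`, but not the parent itself.
    for child in psutil.Process(parent_pid).children(recursive=True):
        child.send_signal(signal.SIGKILL)


def worker() -> None:
    time.sleep(60)  # simulates a long-lived model-server subprocess


if __name__ == "__main__":
    mp.Process(target=worker).start()
    try:
        print("benchmark body runs here")  # stand-in for main(server_args, bench_args)
    finally:
        kill_children(os.getpid())  # children are reaped even if the body raised
```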
sglang/global_config.py
CHANGED
```diff
@@ -1,5 +1,7 @@
 """Global configurations"""
 
+import os
+
 
 class GlobalConfig:
     def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
 
-        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 384 * 1024 * 1024
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )
 
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
 
         # Interpreter optimization configs
-        self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Deprecated
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"
 
 
 global_config = GlobalConfig()
```
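The workspace size is now overridable through an environment variable instead of being hard-coded. A minimal sketch of the override semantics; the `int()` cast is an assumption about how a consumer would normalize the value, since `os.environ.get` returns a string whenever the variable is actually set:

```python
import os

# Falls back to 384 MiB when FLASHINFER_WORKSPACE_SIZE is unset.
workspace_size = int(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))
print(f"flashinfer workspace: {workspace_size // (1024 * 1024)} MiB")
```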
sglang/lang/interpreter.py
CHANGED
```diff
@@ -434,9 +434,6 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
-        # if global_config.eager_fill_image:
-        #     self.backend.fill_image(self)
-
     def _spec_gen(self, sampling_params):
         stop = sampling_params.stop
         max_new_tokens = sampling_params.max_new_tokens
```
sglang/srt/constrained/fsm_cache.py
CHANGED
```diff
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        constrained_json_whitespace_pattern=None,
     ):
         super().__init__(enable=enable)
 
@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
            )
+        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
 
    def init_value(self, key):
        key_type, key_string = key
        if key_type == "json":
-            regex = build_regex_from_schema(key_string)
+            regex = build_regex_from_schema(
+                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
+            )
        elif key_type == "regex":
            regex = key_string
        else:
```
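The new `constrained_json_whitespace_pattern` knob is forwarded into outlines' schema-to-regex conversion. A hedged sketch of its effect; the import path below is what recent outlines releases expose and may differ across versions, and the empty-string pattern is only an illustrative choice:

```python
import json

from outlines.fsm.json_schema import build_regex_from_schema  # path varies by outlines version

schema = json.dumps({"type": "object", "properties": {"name": {"type": "string"}}})

# Default: outlines allows a permissive whitespace pattern between JSON tokens.
default_regex = build_regex_from_schema(schema)

# With the new pass-through, a pattern such as "" forbids whitespace entirely,
# constraining generation to compact JSON.
compact_regex = build_regex_from_schema(schema, whitespace_pattern=r"")

print(len(default_regex), len(compact_regex))
```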
sglang/srt/layers/activation.py
CHANGED
```diff
@@ -13,6 +13,7 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+import logging
 from typing import Optional
 
 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
             act_fn, intermediate_size, input_is_parallel, params_dtype
         )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
```
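The same `is_hip()` guard recurs across several files in this release (activation, layernorm, attention backend, sampler, LoRA manager): FlashInfer is imported or re-exported only on non-ROCm builds, with the vLLM implementations used otherwise. A generic sketch of the pattern, with `is_hip` re-implemented here as a stand-in for `sglang.srt.utils.is_hip` (which may differ in detail):

```python
import logging

logger = logging.getLogger(__name__)


def is_hip() -> bool:
    # Stand-in: torch.version.hip is a string on ROCm builds and None otherwise.
    import torch

    return bool(getattr(torch.version, "hip", None))


if is_hip():
    logger.info("FlashInfer is not available on AMD GPUs. Falling back.")
    # e.g. re-export the vLLM kernels under the same names:
    # from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
```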
sglang/srt/layers/attention_backend.py
CHANGED
```diff
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
 
 class AttentionBackend(ABC):
     """The base class of attention backends"""
@@ -150,7 +154,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Some heuristics to check whether to use ragged forward
         use_ragged = False
         if (
-
+            torch.sum(input_metadata.seq_lens).item() >= 4096
             and self.model_runner.sliding_window_size is None
         ):
             use_ragged = True
@@ -301,10 +305,6 @@ class FlashInferAttnBackend(AttentionBackend):
             layer.layer_id, input_metadata.out_cache_loc, k, v
         )
 
-        if total_num_tokens >= global_config.layer_sync_threshold:
-            # TODO: Revisit this. Why is this synchronize needed?
-            torch.cuda.synchronize()
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
```
sglang/srt/layers/fused_moe/layer.py
CHANGED
```diff
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
 
 
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.a2_scale.max(), requires_grad=False
             )
 
+        # If ROCm, normalize the weights and scales to e4m3fnuz
+        if is_hip():
+            # Normalize the weights and scales
+            w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w13_weight, layer.w13_scale, layer.a13_scale
+            )
+            w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w2_weight, layer.w2_scale, layer.a2_scale
+            )
+            # Reset the parameters
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+            if a13_scale is not None:
+                layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+            if a2_scale is not None:
+                layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max then dequant and requant each expert.
         assert layer.w13_scale is not None
```
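MI300-class ROCm hardware uses `float8_e4m3fnuz`, whose exponent bias differs from `float8_e4m3fn`, so e4m3fn-quantized checkpoints need renormalizing. The sketch below illustrates the idea as I understand the two formats (same payload layout, bias shifted by one, `0x80` means NaN instead of -0.0); it is a hedged sketch, not vLLM's `normalize_e4m3fn_to_e4m3fnuz` itself:

```python
import torch


def to_e4m3fnuz(weight: torch.Tensor, scale: torch.Tensor):
    # Reinterpret the fp8 payload and double the scale so weight * scale is
    # numerically unchanged; remap the -0.0 bit pattern, which would be NaN
    # under e4m3fnuz, to +0.0.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8).clone()
    bits[bits == -128] = 0
    return bits.view(torch.float8_e4m3fnuz), scale * 2.0


w = torch.randn(4, 4).to(torch.float8_e4m3fn)
s = torch.tensor(0.05)
w_fnuz, s_fnuz = to_e4m3fnuz(w, s)
# Dequantized values should agree.
print(torch.allclose(w.float() * s, w_fnuz.float() * s_fnuz))
```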
sglang/srt/layers/layernorm.py
CHANGED
```diff
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
```
sglang/srt/layers/sampler.py
CHANGED
```diff
@@ -1,51 +1,28 @@
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -53,7 +30,15 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +52,16 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
         else:
@@ -80,48 +69,7 @@ class Sampler(CustomOp):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )
 
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +89,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-
-    return batch_next_token_ids, success
+    return batch_next_token_ids
```
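With the custom-op wrapper gone, the PyTorch fallback now samples once with `torch.multinomial(..., num_samples=1)` and returns token ids without a success mask. A self-contained, simplified illustration of the top-k/top-p/min-p masking that fallback performs (it mirrors the shape of `top_k_top_p_min_p_sampling_from_probs_torch` but is not a drop-in replacement):

```python
import torch

torch.manual_seed(0)

probs = torch.softmax(torch.randn(2, 8), dim=-1)  # [batch, vocab]
top_k = torch.tensor([3, 5])
top_p = torch.tensor([0.9, 0.8])
min_p = torch.tensor([0.05, 0.02])

# Sort descending, then zero out entries outside the nucleus / top-k / below min-p.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = probs_sort.cumsum(dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_p

probs_sort[(probs_sum - probs_sort) > top_p.view(-1, 1)] = 0.0          # top-p
probs_sort[torch.arange(probs.shape[-1]) >= top_k.view(-1, 1)] = 0.0    # top-k
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0             # min-p
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

# Draw one sample per row and map it back to vocabulary indices.
sampled = torch.multinomial(probs_sort, num_samples=1)
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled).view(-1)
print(next_token_ids)
```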
sglang/srt/lora/lora_manager.py
CHANGED
```diff
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
@@ -96,10 +99,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +117,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
```
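The loops now iterate `.items()` and `.keys()`, which implies `lora_paths` is treated as a mapping from a user-chosen adapter name to its checkpoint path rather than a plain list of paths. A small sketch of that shape (the names and paths below are placeholders):

```python
lora_paths = {
    "chat-adapter": "/checkpoints/loras/chat-adapter",
    "code-adapter": "/checkpoints/loras/code-adapter",
}

# Mirrors the manager's bookkeeping: configs and ids are keyed by adapter name.
configs = {name: path for name, path in lora_paths.items()}
lora_id = {name: i for i, name in enumerate(lora_paths)}
print(lora_id)  # {'chat-adapter': 0, 'code-adapter': 1}
```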
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -198,6 +198,7 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -414,7 +415,7 @@ class ModelTpServer:
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -807,12 +808,10 @@ class ModelTpServer:
                 unfinished_indices.append(i)
 
             if req.finished() or (
-
-
-
-
-                    or len(req.output_ids) == 1
-                )
+                req.stream
+                and (
+                    self.decode_forward_ct % self.stream_interval == 0
+                    or len(req.output_ids) == 1
                 )
             ):
                 output_rids.append(req.rid)
@@ -937,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
 
 
```
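The simplified streaming condition can be read as a single predicate: flush a request's partial output when it is finished, or when it is a streaming request and either the decode step counter hits the stream interval or the very first output token has just been produced. The sketch below restates that logic with plain arguments (a paraphrase, not sglang code):

```python
def should_emit(finished: bool, stream: bool, decode_forward_ct: int,
                stream_interval: int, num_output_ids: int) -> bool:
    return finished or (
        stream
        and (decode_forward_ct % stream_interval == 0 or num_output_ids == 1)
    )


print(should_emit(False, True, decode_forward_ct=8, stream_interval=8, num_output_ids=5))   # True
print(should_emit(False, False, decode_forward_ct=8, stream_interval=8, num_output_ids=5))  # False
```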
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
```diff
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE torch native implementaiton is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -105,7 +108,15 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs =
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
 
         # Common inputs
         self.max_bs = max(self.capture_bs)
```