sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py
CHANGED
@@ -24,9 +24,8 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope

+from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
@@ -47,17 +47,17 @@ import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
sglang/srt/models/xverse.py
CHANGED
@@ -21,19 +21,19 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
-
-from
-from
-from
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
sglang/srt/models/xverse_moe.py
CHANGED
@@ -18,25 +18,25 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
-from
-from
-from
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
sglang/srt/openai_api/protocol.py
CHANGED
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
    ignore_eos: bool = False
    skip_special_tokens: bool = True
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


class CompletionResponseChoice(BaseModel):
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
    ignore_eos: bool = False
    skip_special_tokens: bool = True
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


class FunctionResponse(BaseModel):
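session_params is an untyped dict on both request models, so the accepted keys are defined by the session handling elsewhere in the runtime rather than by this schema. A hypothetical request using it; the endpoint, model value, and the "session_id" key are assumptions shown only to illustrate where the field sits in the payload:

import requests

# Illustrative only: the keys inside "session_params" are not defined by this
# Pydantic schema, so "session_id" here is a placeholder.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "session_params": {"session_id": "example-session"},
    },
)
print(resp.json())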
sglang/srt/sampling/custom_logit_processor.py
ADDED
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+    """Deserialize a json string to a Callable object.
+    This function is cached to avoid redundant deserialization.
+    """
+    data = json.loads(json_str)
+    return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+    """Abstract base class for callable functions."""
+
+    @abstractmethod
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        """Define the callable behavior."""
+        raise NotImplementedError
+
+    def to_str(self) -> str:
+        """Serialize the callable function to a JSON-compatible string."""
+        return json.dumps({"callable": dill.dumps(self).hex()})
+
+    @classmethod
+    def from_str(cls, json_str: str):
+        """Deserialize a callable function from a JSON string."""
+        return _cache_from_str(json_str)
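For orientation, a minimal sketch of how this new interface could be subclassed and round-tripped. The subclass name and the "token_id"/"bias" parameter keys are invented for the example; only __call__, to_str, and from_str come from the file above:

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BiasTokenLogitProcessor(CustomLogitProcessor):
    # Example only: adds a per-request bias to one token id taken from the
    # per-request custom parameters.
    def __call__(self, logits, custom_param_list=None):
        if custom_param_list is None:
            return logits
        for i, params in enumerate(custom_param_list):
            token_id = params.get("token_id") if params else None
            if token_id is not None:
                logits[i, token_id] += params.get("bias", 0.0)
        return logits


# dill serializes the processor by value, so the hex string can travel with a
# request and be rebuilt (and cached via lru_cache) on the server side.
serialized = BiasTokenLogitProcessor().to_str()
rebuilt = CustomLogitProcessor.from_str(serialized)
out = rebuilt(torch.zeros(1, 8), [{"token_id": 3, "bias": 2.0}])
print(out)  # token 3 of request 0 now carries the +2.0 bias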
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
CHANGED
@@ -3,6 +3,11 @@ from typing import List
import torch

from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties


class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
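On CUDA the sampling_scaling_penalties kernel from sgl_kernel is used; elsewhere the pure-PyTorch branch applies the standard repetition-penalty rule (divide positive logits, multiply negative ones). A self-contained sketch of that fallback rule on toy tensors, independent of the penalizer class:

import torch

# Two requests, four candidate tokens each.
logits = torch.tensor([[2.0, -1.0, 0.5, -0.5],
                       [1.0, 3.0, -2.0, 0.0]])
# Cumulative penalties (1.0 means "not penalized").
penalties = torch.tensor([[1.2, 1.2, 1.0, 1.0],
                          [1.5, 1.5, 1.5, 1.0]])

# Same rule as the else-branch above: penalized tokens always become less likely.
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)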
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -3,11 +3,18 @@ from __future__ import annotations
import dataclasses
import logging
import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch

+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

logger = logging.getLogger(__name__)

@@ -30,6 +37,9 @@ class SamplingBatchInfo:
    # Dispatch in CUDA graph
    need_min_p_sampling: bool

+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool
+
    # Bias Tensors
    vocab_size: int
    grammars: Optional[List] = None
@@ -46,6 +56,14 @@ class SamplingBatchInfo:
    # Device
    device: str = "cuda"

+    # Custom Parameters
+    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+    # Custom Logit Processor
+    custom_logit_processor: Optional[
+        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+    ] = None
+
    @classmethod
    def from_schedule_batch(
        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -70,6 +88,39 @@ class SamplingBatchInfo:
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

+        # Check if any request has custom logit processor
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
+
+        if has_custom_logit_processor:
+            # Merge the same type of custom logit processors together
+            processor_dict = {}
+            for i, r in enumerate(reqs):
+                if r.custom_logit_processor is None:
+                    continue
+                processor_str = r.custom_logit_processor
+                if processor_str not in processor_dict:
+                    processor_dict[processor_str] = []
+                processor_dict[processor_str].append(i)
+
+            merged_custom_logit_processor = {
+                hash(processor_str): (
+                    # The deserialized custom logit processor object
+                    CustomLogitProcessor.from_str(processor_str),
+                    # The mask tensor for the requests that use this custom logit processor
+                    torch.zeros(len(reqs), dtype=torch.bool)
+                    .scatter_(0, torch.tensor(true_indices), True)
+                    .to(device, non_blocking=True),
+                )
+                for processor_str, true_indices in processor_dict.items()
+            }
+            custom_params = [r.sampling_params.custom_params for r in reqs]
+        else:
+            merged_custom_logit_processor = None
+            custom_params = None
+
        ret = cls(
            temperatures=temperatures,
            top_ps=top_ps,
@@ -77,8 +128,11 @@ class SamplingBatchInfo:
            min_ps=min_ps,
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            has_custom_logit_processor=has_custom_logit_processor,
            vocab_size=vocab_size,
            device=device,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
        )
        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -178,6 +232,8 @@ class SamplingBatchInfo:

    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.has_custom_logit_processor:
+            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

        for item in [
            "temperatures",
@@ -190,6 +246,27 @@ class SamplingBatchInfo:
            if value is not None:  # logit_bias can be None
                setattr(self, item, value[new_indices])

+    def _filter_batch_custom_logit_processor(
+        self, unfinished_indices: List[int], new_indices: torch.Tensor
+    ):
+        """Filter the custom logit processor and custom params"""
+
+        self.custom_logit_processor = {
+            k: (p, mask[new_indices])
+            for k, (p, mask) in self.custom_logit_processor.items()
+            if any(
+                mask[new_indices]
+            )  # ignore the custom logit processor whose mask is all False
+        }
+        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
+            self.custom_logit_processor = None
+            self.custom_params = None
+            self.has_custom_logit_processor = False
+
    @staticmethod
    def merge_bias_tensor(
        lhs: torch.Tensor,
@@ -215,9 +292,76 @@ class SamplingBatchInfo:

        return None

+    @staticmethod
+    def merge_custom_logit_processor(
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        bs1: int,
+        bs2: int,
+        device: str,
+    ):
+        if lhs is None and rhs is None:
+            return None
+        lhs, rhs = lhs or {}, rhs or {}
+
+        keys = set(lhs.keys()).union(set(rhs.keys()))
+        merged_dict = {}
+
+        for k in keys:
+            # Get the logit processor object
+            processor = lhs[k][0] if k in lhs else rhs[k][0]
+            # Get and merge the mask tensors from the two dicts
+            left_mask = (
+                lhs[k][1]
+                if k in lhs
+                else torch.zeros(bs1, dtype=torch.bool, device=device)
+            )
+            right_mask = (
+                rhs[k][1]
+                if k in rhs
+                else torch.zeros(bs2, dtype=torch.bool, device=device)
+            )
+            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))

+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
+        return merged_dict
+
    def merge_batch(self, other: "SamplingBatchInfo"):
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+        # Merge the logit bias tensor
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
+        )
+        # Merge the custom logit processors and custom params lists
+        if self.has_custom_logit_processor or other.has_custom_logit_processor:
+            # Merge the custom logit processors
+            self.custom_logit_processor = (
+                SamplingBatchInfo.merge_custom_logit_processor(
+                    self.custom_logit_processor,
+                    other.custom_logit_processor,
+                    len(self),
+                    len(other),
+                    self.device,
+                )
+            )
+            # Merge the custom params lists
+            self.custom_params = self.custom_params or [None] * len(self)
+            other.custom_params = other.custom_params or [None] * len(other)
+            self.custom_params.extend(other.custom_params)
+
+            # Set the flag to True if any of the two has custom logit processor
+            self.has_custom_logit_processor = True
+
+        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
        for item in [
            "temperatures",
            "top_ps",
@@ -229,9 +373,6 @@ class SamplingBatchInfo:
            setattr(self, item, torch.concat([self_val, other_val]))

        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other), self.device
-        )
        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

    def apply_logits_bias(self, logits: torch.Tensor):
@@ -245,11 +386,14 @@ class SamplingBatchInfo:

        # repetition
        if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )

        # Apply regex vocab_mask
        if self.vocab_mask is not None:
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -13,7 +13,7 @@
# ==============================================================================
"""Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

_SAMPLING_EPS = 1e-6

@@ -23,7 +23,7 @@ class SamplingParams:
    The sampling parameters.

    See docs/references/sampling_params.md or
-    https://
+    https://docs.sglang.ai/references/sampling_params.html
    for the documentation.
    """

@@ -48,6 +48,7 @@ class SamplingParams:
        no_stop_trim: bool = False,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
        self.json_schema = json_schema
        self.ebnf = ebnf
        self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params

        # Process some special cases
        if self.temperature < _SAMPLING_EPS: