sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/models/torch_native_llama.py
CHANGED
````diff
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_one_batch --correct \
+  --model meta-llama/Meta-Llama-3-8B \
+  --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+  --tensor-parallel-size 2 \
+  --disable-cuda-graph
+```
+We will enable CUDA Graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
````
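For context on what `tensor_parallel(model, device_mesh)` builds on: PyTorch's tensor-parallel package pairs a 1-D device mesh with per-module parallel styles. The sketch below uses only the public `torch.distributed.tensor.parallel` API on an illustrative two-layer MLP; it is not sglang's wrapper, and the mapping from the `_tp_plan` labels added later in this diff (`"Colwise_Sharded"`, `"Rowwise"`) to these styles is an assumption about what the new `sglang/srt/model_parallel.py` does.

```python
import torch
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumes torch.distributed is already initialized with tp_size ranks,
# e.g. launched via `torchrun --nproc-per-node 2 script.py`.
tp_size = 2
device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))

mlp = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008, bias=False),  # output dim sharded across ranks
    torch.nn.Linear(11008, 4096, bias=False),  # input dim sharded across ranks
).cuda()

# Column-wise shard the first linear, row-wise shard the second; the
# row-wise layer all-reduces its partial outputs back to a full tensor.
parallelize_module(
    mlp,
    device_mesh,
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)
```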
```diff
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
```
```diff
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id:
+    loaded_shard_id: int,
 ):
-    [19 old lines of the loader body, not rendered in the source diff]
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
```
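The rewritten loader works by direct arithmetic: shrink the fused parameter to its post-TP size with `narrow(...).clone()`, then copy only this rank's slice of each sub-weight. A self-contained rendition of the same steps, with made-up sizes and a hard-coded rank (no distributed setup required):

```python
import torch

# Hypothetical sizes: a fused gate_up weight with two 8-row sub-weights,
# hidden size 4, sharded over tp_size=2 ranks; pretend we are rank 1.
output_sizes = [8, 8]
tp_size, tp_rank = 2, 1
hidden = 4

param = torch.empty(sum(output_sizes), hidden)  # full fused parameter

# shard_id -> (shard_offset, shard_size), everything divided by tp_size
offsets, cur = {}, 0
for i, out in enumerate(output_sizes):
    out //= tp_size
    offsets[i] = (cur, out)
    cur += out

# Shrink the parameter to its post-TP size; clone() frees the full tensor.
if cur != param.shape[0]:
    param.data = param.data.narrow(0, 0, cur).clone()

# Load shard 0 ("gate"): take only this rank's rows of the full weight.
loaded_weight = torch.randn(output_sizes[0], hidden)
shard_offset, shard_size = offsets[0]
dst = param.data.narrow(0, shard_offset, shard_size)
dst.copy_(loaded_weight.narrow(0, tp_rank * shard_size, shard_size))

assert param.shape == (sum(output_sizes) // tp_size, hidden)
```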
```diff
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id:
+    loaded_shard_id: str,
 ):
-    [24 old lines of the loader body, not rendered in the source diff, ending with:]
-    param_data = param_data.narrow(0, shard_offset, shard_size)
-    assert param_data.shape == loaded_weight.shape
-    param_data.copy_(loaded_weight)
-    return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
```
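The q/k/v offsets are the same idea applied to a fused attention projection. Plugging in illustrative Llama-3-8B-like numbers (32 query heads, 8 KV heads, head size 128; hypothetical, not taken from this diff) shows how the per-rank layout comes out:

```python
# Per-rank q/k/v layout inside a fused qkv weight (illustrative numbers).
num_heads, num_kv_heads, head_size, tp_size = 32, 8, 128, 2

num_heads //= tp_size      # 16 query heads on this rank
num_kv_heads //= tp_size   # 4 KV heads on this rank

qkv_offsets = {
    "q": (0, num_heads * head_size),                         # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),  # (2048, 512)
    "v": (
        (num_heads + num_kv_heads) * head_size,              # 2560
        num_kv_heads * head_size,                            # 512
    ),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]  # 3072 rows per rank
```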
```diff
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
```
```diff
@@ -205,20 +224,12 @@
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
```
```diff
@@ -385,10 +396,15 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+        # turning off autotune for fp8dq since it doesn't give speedup and
+        # increases compile time significantly
+        torch._inductor.config.max_autotune_gemm_backends = "ATEN"
+
     @torch.no_grad()
     def forward(
         self,
```
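`torch._inductor.config.max_autotune_gemm_backends` is a real inductor knob; pinning it to `"ATEN"` skips Triton GEMM autotuning, which the in-diff comment says brings no speedup for fp8 dynamic quantization while inflating compile time. A minimal sketch of the same setting around `torch.compile` (toy module, illustrative only):

```python
import torch

# Limit inductor's max-autotune GEMM search to ATen kernels, as the diff does.
torch._inductor.config.max_autotune_gemm_backends = "ATEN"

layer = torch.nn.Linear(1024, 1024)
compiled = torch.compile(layer, mode="max-autotune")
out = compiled(torch.randn(8, 1024))
```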
sglang/srt/openai_api/adapter.py
CHANGED
```diff
@@ -989,11 +989,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
             output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
         )
         token_logprobs = []
-        for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+        for token_idx, (token, logprob) in enumerate(
+            zip(logprobs.tokens, logprobs.token_logprobs)
+        ):
             token_bytes = list(token.encode("utf-8"))
             top_logprobs = []
             if logprobs.top_logprobs:
-                for top_token, top_logprob in logprobs.top_logprobs[
+                for top_token, top_logprob in logprobs.top_logprobs[
+                    token_idx
+                ].items():
                     top_token_bytes = list(top_token.encode("utf-8"))
                     top_logprobs.append(
                         TopLogprob(
```
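The fix threads each token's position through the loop so that per-token top-logprobs are looked up by index rather than reusing a single entry. A toy run with hand-made data:

```python
# top_logprobs is a list with one {token: logprob} dict per generated token,
# so it must be indexed by the token's position in the sequence.
tokens = ["Hello", ","]
token_logprobs = [-0.1, -0.5]
top_logprobs = [{"Hello": -0.1, "Hi": -2.3}, {",": -0.5, "!": -1.9}]

for token_idx, (token, logprob) in enumerate(zip(tokens, token_logprobs)):
    for top_token, top_logprob in top_logprobs[token_idx].items():
        print(token_idx, token, top_token, top_logprob)
```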
sglang/srt/openai_api/protocol.py
CHANGED
```diff
@@ -236,7 +236,7 @@ ChatCompletionMessageContentPart = Union[
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant"]
+    role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart]]
 
 
```
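With `"tool"` added to the accepted roles, a chat payload like the following (illustrative content) now validates instead of being rejected by the Pydantic schema:

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
    {"role": "assistant", "content": "Calling the weather tool."},
    {"role": "tool", "content": '{"temp_c": 21, "condition": "sunny"}'},
]
```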
sglang/srt/sampling/penaltylib/orchestrator.py
CHANGED
```diff
@@ -1,40 +1,34 @@
 import abc
 import dataclasses
-import typing
+from typing import List, Set, Type, Union
 
 import torch
 
 
 @dataclasses.dataclass
 class _ReqLike:
-    origin_input_ids: typing.List[int]
+    origin_input_ids: List[int]
 
 
 @dataclasses.dataclass
 class _BatchLike:
-    reqs: typing.List[_ReqLike]
+    reqs: List[_ReqLike]
 
     def batch_size(self):
         return len(self.reqs)
 
 
 class BatchedPenalizerOrchestrator:
-    batch: _BatchLike
-    device: str
-    vocab_size: int
-    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
-
     def __init__(
         self,
         vocab_size: int,
         batch: _BatchLike,
         device: str,
-        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+        Penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
         self.batch = batch
         self.device = device
-
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
         is_required = False
```
```diff
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+        input_ids = [
+            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
+            for req in self.reqs()
+        ]
         if self.is_required:
-            self.cumulate_input_tokens(
-                input_ids=[req.origin_input_ids for req in self.reqs()]
-            )
+            self.cumulate_input_tokens(input_ids=input_ids)
 
     def reqs(self):
         return self.batch.reqs
```
```diff
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
     def batch_size(self):
         return self.batch.batch_size()
 
-    def cumulate_input_tokens(
-        self,
-        input_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
         """
         Feed the input tokens to the penalizers.
 
         Args:
-            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+            input_ids (List[torch.Tensor]): The input tokens.
         """
         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
 
         for penalizer in self.penalizers.values():
             penalizer.cumulate_input_tokens(input_ids=token_ids)
 
-    def cumulate_output_tokens(
-        self,
-        output_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_output_tokens(self, output_ids: torch.Tensor):
         """
         Feed the output tokens to the penalizers.
 
         Args:
-            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+            output_ids (torch.Tensor): The output tokens.
         """
         if not self.is_required:
             return
```
```diff
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
 
     def filter(
         self,
-        indices_to_keep: typing.List[int],
+        indices_to_keep: List[int],
         indices_tensor_to_keep: torch.Tensor = None,
     ):
         """
         Filter the penalizers based on the indices to keep in the batch.
 
         Args:
-            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+            indices_to_keep (List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
         if not self.is_required:
```
```diff
@@ -174,32 +160,18 @@ class _TokenIDs:
 
     Attributes:
         orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
-        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
         cached_counts (torch.Tensor): The cached occurrence count tensor.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
-    cached_counts: torch.Tensor = None
-
     def __init__(
         self,
        orchestrator: BatchedPenalizerOrchestrator,
-        token_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
+        token_ids: Union[torch.Tensor, List[torch.Tensor]],
     ):
         self.orchestrator = orchestrator
-
-        if not isinstance(token_ids[0], torch.Tensor):
-            token_ids = [
-                torch.tensor(
-                    data=ids, dtype=torch.int64, device=self.orchestrator.device
-                )
-                for ids in token_ids
-            ]
-
         self.token_ids = token_ids
+        self.cached_counts = None
 
     def occurrence_count(self) -> torch.Tensor:
         """
```
```diff
@@ -213,30 +185,34 @@ class _TokenIDs:
 
         token_ids = self.token_ids
 
-        if isinstance(token_ids,
-        [23 more old lines of this method, not rendered in the source diff]
+        if isinstance(token_ids, list):
+            # TODO: optimize this part
+            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+                sequences=token_ids,
+                batch_first=True,
+                padding_value=self.orchestrator.vocab_size,
+            )
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+                dtype=torch.int64,
+                device=self.orchestrator.device,
+            ).scatter_add_(
+                dim=1,
+                index=padded_token_ids,
+                src=torch.ones_like(padded_token_ids),
+            )[
+                :, : self.orchestrator.vocab_size
+            ]
+        else:
+            # TODO: optimize this part. We do not need to create this big tensor every time.
+            # We can directly apply the results on the logits.
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
+                device=self.orchestrator.device,
+            )
+            self.cached_counts[
+                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
+            ] = 1
 
         return self.cached_counts
 
```
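The list branch pads variable-length prompts with `vocab_size` (one index past the real vocabulary), counts occurrences with `scatter_add_`, then slices the padding column off. A stand-alone illustration with a toy vocabulary:

```python
import torch

vocab_size = 5
token_ids = [torch.tensor([0, 2, 2]), torch.tensor([4])]  # two prompts

padded = torch.nn.utils.rnn.pad_sequence(
    token_ids, batch_first=True, padding_value=vocab_size
)  # pad index 5 lands in an extra, throwaway column

counts = torch.zeros(len(token_ids), vocab_size + 1, dtype=torch.int64).scatter_add_(
    dim=1, index=padded, src=torch.ones_like(padded)
)[:, :vocab_size]

print(counts)
# tensor([[1, 0, 2, 0, 0],
#         [0, 0, 0, 0, 1]])
```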
```diff
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    _is_prepared: bool = False
-
     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
         self.orchestrator = orchestrator
+        self._is_prepared = False
 
     def is_prepared(self) -> bool:
         return self._is_prepared
```
```diff
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
 
         return self._apply(logits=logits)
 
-    def filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         if not self.is_prepared():
             return
 
```
```diff
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         """
         Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
         """
```
sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py
CHANGED
```diff
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.frequency_penalties
-        del self.cumulated_frequency_penalties
-
         self.frequency_penalties = None
         self.cumulated_frequency_penalties = None
 
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         logits -= self.cumulated_frequency_penalties
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
         self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
             indices_tensor_to_keep
```
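The `_apply` hunk above shows the frequency penalty as a plain subtraction of the cumulated per-token penalties. Combined with the occurrence counts from `_TokenIDs.occurrence_count()`, the arithmetic reduces to the following toy example (names illustrative, not the class's internals):

```python
import torch

vocab_size = 5
frequency_penalty = torch.tensor([[0.5]])  # one request in the batch
counts = torch.tensor([[1, 0, 2, 0, 0]])   # occurrence counts so far
logits = torch.zeros(1, vocab_size)

# Each occurrence of a token subtracts `frequency_penalty` from its logit.
cumulated = frequency_penalty * counts
logits -= cumulated
print(logits)  # tensor([[-0.5,  0.0, -1.0,  0.0,  0.0]])
```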
sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
CHANGED
```diff
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.min_new_tokens
-        del self.stop_token_penalties
-        del self.len_output_tokens
-
         self.min_new_tokens = None
         self.stop_token_penalties = None
         self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         logits[mask] += self.stop_token_penalties[mask]
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
         self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
         self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
```
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py
CHANGED
```diff
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.presence_penalties
-        del self.cumulated_presence_penalties
-
         self.presence_penalties = None
         self.cumulated_presence_penalties = None
 
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         logits -= self.cumulated_presence_penalties
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
             indices_tensor_to_keep
```
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
CHANGED
```diff
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.repetition_penalties
-        del self.cumulated_repetition_penalties
-
         self.repetition_penalties = None
         self.cumulated_repetition_penalties = None
 
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
             logits * self.cumulated_repetition_penalties,
         )
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
             indices_tensor_to_keep
```
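The visible `logits * self.cumulated_repetition_penalties` branch is half of the usual repetition-penalty rule: logits of already-seen tokens are divided by the penalty when positive and multiplied when negative, so both cases push probability down. A toy sketch of that rule (variable names illustrative):

```python
import torch

repetition_penalty = 1.2
logits = torch.tensor([2.0, -1.0, 0.5])
seen = torch.tensor([True, True, False])  # tokens already generated

# Seen tokens get the penalty factor; unseen tokens keep a factor of 1.0.
penalty = torch.where(seen, torch.tensor(repetition_penalty), torch.tensor(1.0))
logits = torch.where(logits > 0, logits / penalty, logits * penalty)
print(logits)  # tensor([ 1.6667, -1.2000,  0.5000])
```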