sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff shows the contents of the two publicly released package versions as they appear in their public registry; it is provided for informational purposes only.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_one_batch --correct \
+ --model meta-llama/Meta-Llama-3-8B \
+ --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+ --tensor-parallel-size 2 \
+ --disable-cuda-graph
+```
+We will eanble CUDA Graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
 
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id:
+    loaded_shard_id: int,
 ):
-    ... (19 removed lines collapsed in the diff view) ...
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id:
+    loaded_shard_id: str,
 ):
-    ... (24 removed lines collapsed in the diff view) ...
-    param_data = param_data.narrow(0, shard_offset, shard_size)
-    assert param_data.shape == loaded_weight.shape
-    param_data.copy_(loaded_weight)
-    return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
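Both loaders above pick out one rank's shard of a fused weight with `narrow` along dim 0. A toy sketch of that slicing with made-up shapes (hypothetical values, not sglang code):

```python
import torch

# Hypothetical shapes: a fused projection weight with 8 output rows,
# split across 2 tensor-parallel ranks; rank 1 keeps rows [4, 8).
tp_size, tp_rank = 2, 1
full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)

shard_size = full_weight.shape[0] // tp_size
local_weight = full_weight.narrow(0, tp_rank * shard_size, shard_size)

assert local_weight.shape == (4, 4)
assert torch.equal(local_weight, full_weight[4:8])
```

This is the same pattern `loaded_weight.narrow(0, tp_rank * shard_size, shard_size)` applies per gate/up and per q/k/v shard in the loaders above.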
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
@@ -385,10 +396,15 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+        # turning off autotune for fp8dq since it doesn't give speedup and
+        # increases compile time significantly
+        torch._inductor.config.max_autotune_gemm_backends = "ATEN"
+
     @torch.no_grad()
     def forward(
         self,
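The `_tp_plan` dictionaries added above describe how each projection should be split across ranks ("Colwise_Sharded" for the fused column-parallel weights, "Rowwise" for the output projections). A minimal sketch of how such a plan could be mapped onto PyTorch's tensor-parallel styles; `apply_tp_plan` and the style mapping are hypothetical stand-ins, not sglang's actual `sglang/srt/model_parallel.py`:

```python
import torch
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumed mapping: "Colwise_Sharded" is treated here like a plain column-wise
# split; the real backend may handle the fused gate/up and qkv sharding differently.
_STYLES = {
    "Colwise": ColwiseParallel(),
    "Colwise_Sharded": ColwiseParallel(),
    "Rowwise": RowwiseParallel(),
}


def apply_tp_plan(model: torch.nn.Module, tp_size: int) -> torch.nn.Module:
    """Parallelize every submodule that declares a `_tp_plan` attribute."""
    device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
    for module in model.modules():
        plan = getattr(module, "_tp_plan", None)
        if plan:
            parallelize_module(
                module,
                device_mesh,
                {name: _STYLES[style] for name, style in plan.items()},
            )
    return model
```

The real entry point is `tensor_parallel(model, device_mesh)` from `sglang.srt.model_parallel`, as shown in the module docstring above.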
sglang/srt/openai_api/adapter.py
CHANGED
@@ -516,8 +516,9 @@ def v1_generate_request(
                 "regex": request.regex,
                 "json_schema": request.json_schema,
                 "n": request.n,
-                "ignore_eos": request.ignore_eos,
                 "no_stop_trim": request.no_stop_trim,
+                "ignore_eos": request.ignore_eos,
+                "skip_special_tokens": request.skip_special_tokens,
             }
         )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
@@ -928,7 +929,9 @@ def v1_chat_generate_request(
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
             "n": request.n,
+            "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
+            "skip_special_tokens": request.skip_special_tokens,
         }
         if request.response_format and request.response_format.type == "json_schema":
             sampling_params["json_schema"] = convert_json_schema_to_str(
@@ -986,11 +989,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
            token_logprobs = []
-            for token, logprob in
+            for token_idx, (token, logprob) in enumerate(
+                zip(logprobs.tokens, logprobs.token_logprobs)
+            ):
                token_bytes = list(token.encode("utf-8"))
                top_logprobs = []
                if logprobs.top_logprobs:
-                    for top_token, top_logprob in logprobs.top_logprobs[
+                    for top_token, top_logprob in logprobs.top_logprobs[
+                        token_idx
+                    ].items():
                        top_token_bytes = list(top_token.encode("utf-8"))
                        top_logprobs.append(
                            TopLogprob(
@@ -1166,7 +1173,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                        is_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
-                            delta=DeltaMessage(role="assistant"),
+                            delta=DeltaMessage(role="assistant", content=""),
                            finish_reason=(
                                finish_reason["type"] if finish_reason else ""
                            ),
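With the adapter now forwarding `no_stop_trim`, `ignore_eos`, and `skip_special_tokens`, these SRT-only extras can be sent straight through the OpenAI-compatible endpoints. A minimal sketch, assuming a server is already running on the default port 30000 (the URL, model name, and prompt are placeholders to adjust for your setup):

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",  # whatever model the server was launched with
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 32,
        # SRT-only extras defined in protocol.py; the upstream OpenAI API ignores them.
        "skip_special_tokens": False,  # keep special tokens in the decoded text
        "no_stop_trim": True,          # do not trim the matched stop string
        "ignore_eos": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```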
sglang/srt/openai_api/protocol.py
CHANGED
@@ -36,7 +36,7 @@ class ModelList(BaseModel):
     """Model list consists of model cards."""
 
     object: str = "list"
-    data: List[ModelCard] =
+    data: List[ModelCard] = Field(default_factory=list)
 
 
 class ErrorResponse(BaseModel):
@@ -143,7 +143,7 @@ class BatchResponse(BaseModel):
     expired_at: Optional[int] = None
     cancelling_at: Optional[int] = None
     cancelled_at: Optional[int] = None
-    request_counts: dict =
+    request_counts: Optional[dict] = None
     metadata: Optional[dict] = None
 
 
@@ -153,30 +153,31 @@ class CompletionRequest(BaseModel):
     model: str
     prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
-    echo:
-    frequency_penalty:
+    echo: bool = False
+    frequency_penalty: float = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
-    max_tokens:
+    max_tokens: int = 16
     n: int = 1
-    presence_penalty:
+    presence_penalty: float = 0.0
     seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] =
-    stream:
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
-    temperature:
-    top_p:
+    temperature: float = 1.0
+    top_p: float = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    regex: Optional[str] = None
     json_schema: Optional[str] = None
-
+    regex: Optional[str] = None
     min_tokens: int = 0
-    repetition_penalty:
-    stop_token_ids: Optional[List[int]] =
-    no_stop_trim:
+    repetition_penalty: float = 1.0
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
+    ignore_eos: bool = False
+    skip_special_tokens: bool = True
 
 
 class CompletionResponseChoice(BaseModel):
@@ -235,7 +236,7 @@ ChatCompletionMessageContentPart = Union[
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant"]
+    role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart]]
 
 
@@ -259,28 +260,30 @@ class ChatCompletionRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
     model: str
-    frequency_penalty:
+    frequency_penalty: float = 0.0
     logit_bias: Optional[Dict[str, float]] = None
-    logprobs:
+    logprobs: bool = False
     top_logprobs: Optional[int] = None
     max_tokens: Optional[int] = None
-    n:
-    presence_penalty:
+    n: int = 1
+    presence_penalty: float = 0.0
     response_format: Optional[ResponseFormat] = None
     seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] =
-    stream:
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = False
     stream_options: Optional[StreamOptions] = None
-    temperature:
-    top_p:
+    temperature: float = 0.7
+    top_p: float = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
-    min_tokens:
-    repetition_penalty:
-    stop_token_ids: Optional[List[int]] =
+    min_tokens: int = 0
+    repetition_penalty: float = 1.0
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
     ignore_eos: bool = False
+    skip_special_tokens: bool = True
 
 
 class ChatMessage(BaseModel):
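As an aside on the `Field(default_factory=list)` default now used for `ModelList.data`: the factory gives every instance its own fresh list. A small standalone pydantic illustration (demo class, not sglang code):

```python
from typing import List

from pydantic import BaseModel, Field


class ModelListDemo(BaseModel):
    object: str = "list"
    data: List[str] = Field(default_factory=list)  # a fresh list per instance


a, b = ModelListDemo(), ModelListDemo()
a.data.append("model-a")
assert b.data == []  # the two instances do not share one list
```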
sglang/srt/sampling/penaltylib/orchestrator.py
CHANGED
@@ -1,40 +1,34 @@
 import abc
 import dataclasses
-import
+from typing import List, Set, Type, Union
 
 import torch
 
 
 @dataclasses.dataclass
 class _ReqLike:
-    origin_input_ids:
+    origin_input_ids: List[int]
 
 
 @dataclasses.dataclass
 class _BatchLike:
-    reqs:
+    reqs: List[_ReqLike]
 
     def batch_size(self):
         return len(self.reqs)
 
 
 class BatchedPenalizerOrchestrator:
-    batch: _BatchLike
-    device: str
-    vocab_size: int
-    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
-
     def __init__(
         self,
         vocab_size: int,
         batch: _BatchLike,
         device: str,
-        Penalizers:
+        Penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
         self.batch = batch
         self.device = device
-
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
         is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+        input_ids = [
+            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
+            for req in self.reqs()
+        ]
         if self.is_required:
-            self.cumulate_input_tokens(
-                input_ids=[req.origin_input_ids for req in self.reqs()]
-            )
+            self.cumulate_input_tokens(input_ids=input_ids)
 
     def reqs(self):
         return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
     def batch_size(self):
         return self.batch.batch_size()
 
-    def cumulate_input_tokens(
-        self,
-        input_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
         """
         Feed the input tokens to the penalizers.
 
         Args:
-            input_ids (
+            input_ids (List[torch.Tensor]): The input tokens.
         """
         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
 
         for penalizer in self.penalizers.values():
             penalizer.cumulate_input_tokens(input_ids=token_ids)
 
-    def cumulate_output_tokens(
-        self,
-        output_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_output_tokens(self, output_ids: torch.Tensor):
         """
         Feed the output tokens to the penalizers.
 
         Args:
-            output_ids (
+            output_ids (torch.Tensor): The output tokens.
         """
         if not self.is_required:
             return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
 
     def filter(
         self,
-        indices_to_keep:
+        indices_to_keep: List[int],
         indices_tensor_to_keep: torch.Tensor = None,
     ):
         """
         Filter the penalizers based on the indices to keep in the batch.
 
         Args:
-            indices_to_keep (
+            indices_to_keep (List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
         if not self.is_required:
@@ -174,32 +160,18 @@ class _TokenIDs:
 
     Attributes:
        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
-        token_ids (
+        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
        cached_counts (torch.Tensor): The cached occurrence count tensor.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
-    cached_counts: torch.Tensor = None
-
     def __init__(
         self,
         orchestrator: BatchedPenalizerOrchestrator,
-        token_ids:
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
+        token_ids: Union[torch.Tensor, List[torch.Tensor]],
     ):
         self.orchestrator = orchestrator
-
-        if not isinstance(token_ids[0], torch.Tensor):
-            token_ids = [
-                torch.tensor(
-                    data=ids, dtype=torch.int64, device=self.orchestrator.device
-                )
-                for ids in token_ids
-            ]
-
         self.token_ids = token_ids
+        self.cached_counts = None
 
     def occurrence_count(self) -> torch.Tensor:
         """
@@ -213,30 +185,34 @@ class _TokenIDs:
 
         token_ids = self.token_ids
 
-        if isinstance(token_ids,
-        ... (23 removed lines collapsed in the diff view) ...
+        if isinstance(token_ids, list):
+            # TODO: optimize this part
+            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+                sequences=token_ids,
+                batch_first=True,
+                padding_value=self.orchestrator.vocab_size,
+            )
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+                dtype=torch.int64,
+                device=self.orchestrator.device,
+            ).scatter_add_(
+                dim=1,
+                index=padded_token_ids,
+                src=torch.ones_like(padded_token_ids),
+            )[
+                :, : self.orchestrator.vocab_size
+            ]
+        else:
+            # TODO: optimize this part. We do not need to create this big tensor every time.
+            # We can directly apply the results on the logits.
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
+                device=self.orchestrator.device,
+            )
+            self.cached_counts[
+                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
+            ] = 1
 
         return self.cached_counts
 
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    _is_prepared: bool = False
-
     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
         self.orchestrator = orchestrator
+        self._is_prepared = False
 
     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
 
         return self._apply(logits=logits)
 
-    def filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         if not self.is_prepared():
             return
 
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        """
        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
        """
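The new `occurrence_count` path counts token occurrences by padding the per-request ID tensors with `vocab_size`, scatter-adding ones into a `(batch, vocab_size + 1)` table, and slicing the padding bucket away. A standalone sketch with toy values (assumed shapes only, not wired into the orchestrator):

```python
import torch

vocab_size = 8
token_ids = [torch.tensor([1, 2, 2, 5]), torch.tensor([3, 3])]

# Pad with vocab_size so the padding bucket can be sliced off afterwards.
padded = torch.nn.utils.rnn.pad_sequence(
    token_ids, batch_first=True, padding_value=vocab_size
)
counts = torch.zeros(
    (len(token_ids), vocab_size + 1), dtype=torch.int64
).scatter_add_(dim=1, index=padded, src=torch.ones_like(padded))[:, :vocab_size]

print(counts[0])  # tensor([0, 1, 2, 0, 0, 1, 0, 0]): token 2 occurred twice
print(counts[1])  # tensor([0, 0, 0, 2, 0, 0, 0, 0]): token 3 occurred twice
```

The extra column absorbs the padding positions, so shorter sequences in the batch do not distort the counts.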