sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -25,7 +25,7 @@ from http import HTTPStatus
 from typing import Dict, List
 
 from fastapi import HTTPException, Request, UploadFile
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse, StreamingResponse
 from pydantic import ValidationError
 
 try:
@@ -101,7 +101,7 @@ def create_error_response(
     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
 ):
     error = ErrorResponse(message=message, type=err_type, code=status_code.value)
-    return
+    return ORJSONResponse(content=error.model_dump(), status_code=error.code)
 
 
 def create_streaming_error_response(
@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if not isinstance(ret, list):
             ret = [ret]
         if end_point == "/v1/chat/completions":
-            responses = v1_chat_generate_response(
+            responses = v1_chat_generate_response(
+                request,
+                ret,
+                to_file=True,
+                cache_report=tokenizer_manager.server_args.enable_cache_report,
+            )
         else:
             responses = v1_generate_response(
                 request, ret, tokenizer_manager, to_file=True
@@ -493,23 +498,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        sampling_params = []
+        if isinstance(request.no_stop_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_stop_trim": (
+                        request.no_stop_trim
+                        if not isinstance(request.no_stop_trim, list)
+                        else request.no_stop_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)
 
     if len(all_requests) == 1:
         prompt = prompts[0]
@@ -601,16 +621,19 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         else:
             logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": (
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -618,10 +641,11 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=(
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             )
 
@@ -751,14 +775,16 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
                 delta = text[len(stream_buffer) :]
                 stream_buffer = stream_buffer + delta
+                finish_reason = content["meta_info"]["finish_reason"]
                 choice_data = CompletionResponseStreamChoice(
                     index=index,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=(
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                 )
                 chunk = CompletionStreamResponse(
@@ -910,6 +936,7 @@ def v1_chat_generate_request(
         "repetition_penalty": request.repetition_penalty,
         "regex": request.regex,
         "n": request.n,
+        "ignore_eos": request.ignore_eos,
     }
     if request.response_format and request.response_format.type == "json_schema":
         sampling_params["json_schema"] = convert_json_schema_to_str(
@@ -954,7 +981,7 @@ def v1_chat_generate_request(
     return adapted_request, all_requests
 
 
-def v1_chat_generate_response(request, ret, to_file=False):
+def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
     choices = []
 
     for idx, ret_item in enumerate(ret):
@@ -995,16 +1022,19 @@ def v1_chat_generate_response(request, ret, to_file=False):
         else:
             choice_logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": (
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -1012,10 +1042,11 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=(
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
            )
 
@@ -1051,6 +1082,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
         ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
     )
     completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+    cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
     response = ChatCompletionResponse(
         id=ret[0]["meta_info"]["id"],
         model=request.model,
@@ -1059,6 +1091,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=(
+                {"cached_tokens": cached_tokens} if cache_report else None
+            ),
         ),
     )
     return response
@@ -1134,6 +1169,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 else:
                     choice_logprobs = None
 
+                finish_reason = content["meta_info"]["finish_reason"]
+
                 if is_first:
                     # First chunk with role
                     is_first = False
@@ -1141,9 +1178,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=index,
                         delta=DeltaMessage(role="assistant"),
                         finish_reason=(
+                            finish_reason["type"] if finish_reason else ""
+                        ),
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
                         ),
                         logprobs=choice_logprobs,
                     )
@@ -1160,10 +1200,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=(
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                    logprobs=choice_logprobs,
                )
@@ -1224,7 +1265,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_chat_generate_response(
+    response = v1_chat_generate_response(
+        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+    )
 
     return response
 
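Note: the same finish-reason handling recurs in every response path touched above. Roughly, each handler now does the equivalent of the sketch below (the helper name is illustrative, not part of the package; `meta_info["finish_reason"]` is the dict the adapter reads, carrying a "type" key and an optional "matched" key):

    # Illustrative sketch only; mirrors the pattern added throughout adapter.py.
    def split_finish_reason(meta_info: dict):
        finish_reason = meta_info["finish_reason"]
        finish_reason_type = finish_reason["type"] if finish_reason else ""
        matched_stop = (
            finish_reason["matched"]
            if finish_reason and "matched" in finish_reason
            else None
        )
        return finish_reason_type, matched_stop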
sglang/srt/openai_api/protocol.py
CHANGED
@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    # only used to return cached tokens when --enable-cache-report is set
+    prompt_tokens_details: Optional[Dict[str, int]] = None
 
 
 class StreamOptions(BaseModel):
@@ -170,10 +172,11 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     json_schema: Optional[str] = None
-    ignore_eos:
-    min_tokens:
+    ignore_eos: bool = False
+    min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_stop_trim: Union[bool, List[bool]] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -181,6 +184,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionResponse(BaseModel):
@@ -197,6 +201,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionStreamResponse(BaseModel):
@@ -275,6 +280,7 @@ class ChatCompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: bool = False
 
 
 class ChatMessage(BaseModel):
@@ -287,6 +293,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: str
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -308,6 +315,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
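Together with the adapter changes, these schema additions surface two new pieces of information to clients: which stop condition actually fired (`matched_stop`) and, when the server runs with `--enable-cache-report`, how many prompt tokens were reported as cached. A purely illustrative response fragment (all values made up):

    # Illustrative only: shape of the new fields in a chat completion response.
    response_fragment = {
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "..."},
                "finish_reason": "stop",
                "matched_stop": "</s>",  # new: stop string or token id that ended generation
            }
        ],
        "usage": {
            "prompt_tokens": 32,
            "completion_tokens": 8,
            "total_tokens": 40,
            "prompt_tokens_details": {"cached_tokens": 16},  # only with --enable-cache-report
        },
    }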
sglang/srt/sampling/penaltylib/orchestrator.py
CHANGED
@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
 
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            penalizer.prepare_if_required()
+            pen_is_required = penalizer.prepare_if_required()
+            is_required |= pen_is_required
+        self.is_required = is_required
 
-        self.
+        if self.is_required:
+            self.cumulate_input_tokens(
+                input_ids=[req.origin_input_ids for req in self.reqs()]
+            )
 
     def reqs(self):
         return self.batch.reqs
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
         Args:
             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
         """
+        if not self.is_required:
+            return
+
         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
 
         for penalizer in self.penalizers.values():
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
         Returns:
             torch.Tensor: The logits after applying the penalizers.
         """
+        if not self.is_required:
+            return
+
         for penalizer in self.penalizers.values():
             logits = penalizer.apply(logits)
 
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
            indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
+        if not self.is_required:
+            return
+
         empty_indices = len(indices_to_keep) == 0
 
+        is_required = False
         for penalizer in self.penalizers.values():
+            tmp_is_required = penalizer.is_required()
+            is_required = is_required or tmp_is_required
+            if not tmp_is_required or empty_indices:
                 penalizer.teardown()
             else:
                 # create tensor index only when it's needed
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
                     indices_to_keep=indices_to_keep,
                    indices_tensor_to_keep=indices_tensor_to_keep,
                 )
+        self.is_required = is_required
 
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
         Args:
             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
         """
-        if self.
-                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
-            )
+        if not self.is_required and not their.is_required:
+            return
 
+        self.is_required |= their.is_required
         for Penalizer, their_penalizer in their.penalizers.items():
             if Penalizer not in self.penalizers:
                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
     def prepare_if_required(self):
         if self.is_required():
             self.prepare()
+            return True
+        else:
+            return False
 
     def teardown(self):
         if self.is_prepared():
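The orchestrator changes above add a fast path: each penalizer now reports from `prepare_if_required()` whether it is actually needed, and the orchestrator keeps an aggregate `is_required` flag so that batches without any penalties skip token accumulation, filtering, merging, and logit adjustment. A condensed sketch of that control flow (simplified, not the real class):

    # Condensed sketch of the is_required short-circuit (simplified).
    class OrchestratorSketch:
        def __init__(self, penalizers):
            self.penalizers = penalizers
            is_required = False
            for p in self.penalizers:
                # prepare_if_required() now returns True only for active penalizers.
                is_required |= p.prepare_if_required()
            self.is_required = is_required

        def apply(self, logits):
            # Penalty-free batches skip all penalizer work.
            if not self.is_required:
                return logits
            for p in self.penalizers:
                logits = p.apply(logits)
            return logits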
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor
     min_ps: torch.Tensor
 
+    # All requests use greedy sampling
+    is_all_greedy: bool
+
     # Dispatch in CUDA graph
     need_min_p_sampling: bool
 
@@ -33,30 +36,39 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
 
     # Device
     device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
+        device = batch.input_ids.device
+        temperatures = (
+            torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
-        ).view(-1, 1)
-        top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float
-        )
-        top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int
-        )
-        min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
+            .view(-1, 1)
+            .to(device, non_blocking=True)
+        )
+        top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
+        top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+        ).to(device, non_blocking=True)
+        min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
 
         ret = cls(
             temperatures=temperatures,
@@ -64,6 +76,7 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
             device=batch.input_ids.device,
         )
@@ -75,18 +88,21 @@ class SamplingBatchInfo:
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {
+        # handle {filter_batch()} and {merge_batch()} cases as well.
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -97,46 +113,50 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None
 
         for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if not penalizer.is_prepared():
+                continue
+
             if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                self.scaling_penalties = penalizer.cumulated_repetition_penalties
+                self.scaling_penalties = penalizer.cumulated_repetition_penalties
             else:
-                if
-                self.
-                self.linear_penalties = penalizer.apply(self.linear_penalties)
+                if self.linear_penalties is None:
+                    bs = self.penalizer_orchestrator.batch.batch_size()
+                    self.linear_penalties = torch.zeros(
+                        (bs, self.vocab_size),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        self.
-        self.
-        ] = 0
+        if not has_regex:
+            self.vocab_mask = None
+            return
+
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -175,7 +195,8 @@ class SamplingBatchInfo:
             return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
@@ -187,6 +208,19 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+
+    def copy(self):
+        return SamplingBatchInfo(
+            temperatures=self.temperatures,
+            top_ps=self.top_ps,
+            top_ks=self.top_ks,
+            min_ps=self.min_ps,
+            is_all_greedy=self.is_all_greedy,
+            need_min_p_sampling=self.need_min_p_sampling,
+            vocab_size=self.vocab_size,
+            device=self.device,
+        )
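The new `is_all_greedy` flag is computed from the per-request top-k values (`top_ks.max().item() <= 1`), so the sampler can tell when an entire batch is effectively greedy. A minimal illustration of the kind of dispatch this enables (not the sampler's actual code):

    import torch

    # Minimal illustration of a greedy fast path keyed off is_all_greedy.
    top_ks = torch.tensor([1, 1, 1], dtype=torch.int32)
    is_all_greedy = top_ks.max().item() <= 1

    def pick_next_tokens(logits: torch.Tensor) -> torch.Tensor:
        if is_all_greedy:
            # Every request is greedy, so a plain argmax suffices.
            return torch.argmax(logits, dim=-1)
        # Otherwise fall back to a sampling path (top-k/top-p handling omitted).
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)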
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -40,6 +40,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_stop_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_stop_trim = no_stop_trim
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
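`no_stop_trim` reaches SamplingParams from the OpenAI-compatible completions API (see the CompletionRequest change above); as the name suggests, it asks the server not to trim the matched stop sequence from the returned text. A hedged usage example (server URL, port, and model name are placeholders for a locally running sglang instance):

    import requests

    # Assumes a local sglang server; "no_stop_trim" and "ignore_eos" are the
    # SRT-only extras added in this release and are ignored by upstream OpenAI.
    resp = requests.post(
        "http://localhost:30000/v1/completions",
        json={
            "model": "default",  # placeholder model name
            "prompt": "List three colors:",
            "max_tokens": 32,
            "stop": ["\n\n"],
            "no_stop_trim": True,  # keep the matched stop text in the output
            "ignore_eos": False,
        },
    )
    choice = resp.json()["choices"][0]
    print(choice["text"], choice.get("matched_stop"))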
|