sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their registry, and is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
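Several modules were renamed in this release (marked with arrows in the list above). Downstream code that imported them by their old paths has to switch to the new ones; a minimal sketch of the affected imports, assuming anything outside sglang used them at all (the modules still require the usual GPU/triton runtime to import):

```python
# Renames in 0.2.13, taken from the file list above.
# 0.2.11:
#   from sglang.srt.layers import token_attention
#   from sglang.srt.layers import context_flashattention_nopad
#   from sglang.srt.mem_cache import base_cache
# 0.2.13:
from sglang.srt.layers import decode_attention
from sglang.srt.layers import prefill_attention
from sglang.srt.mem_cache import base_prefix_cache
```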
sglang/srt/models/qwen2_moe.py
CHANGED
```diff
@@ -28,9 +28,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -46,9 +44,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
```
sglang/srt/models/stablelm.py
CHANGED
```diff
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
```
sglang/srt/openai_api/adapter.py
CHANGED
```diff
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
```
```diff
@@ -52,6 +52,9 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
     FileDeleteResponse,
     FileRequest,
```
```diff
@@ -74,7 +77,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}
 
 
```
```diff
@@ -82,6 +85,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None
 
 
+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
```
```diff
@@ -101,7 +117,7 @@ def create_streaming_error_response(
     return json_str
 
 
-def load_chat_template_for_openai_api(chat_template_arg):
+def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
     print(f"Use chat template: {chat_template_arg}")
```
```diff
@@ -111,27 +127,38 @@ def load_chat_template_for_openai_api(chat_template_arg):
                 f"Chat template {chat_template_arg} is not a built-in template name "
                 "or a valid chat template file path."
             )
-        with open(chat_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                sep_style = SeparatorStyle[template["sep_style"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown separator style: {template['sep_style']}"
-                ) from None
-            register_conv_template(
-                Conversation(
-                    name=template["name"],
-                    system_template=template["system"] + "\n{system_message}",
-                    system_message=template.get("system_message", ""),
-                    roles=(template["user"], template["assistant"]),
-                    sep_style=sep_style,
-                    sep=template.get("sep", "\n"),
-                    stop_str=template["stop_str"],
-                ),
-                override=True,
+        if chat_template_arg.endswith(".jinja"):
+            with open(chat_template_arg, "r") as f:
+                chat_template = "".join(f.readlines()).strip("\n")
+            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
+                "\\n", "\n"
             )
-        chat_template_name = template["name"]
+            chat_template_name = None
+        else:
+            assert chat_template_arg.endswith(
+                ".json"
+            ), "unrecognized format of chat template file"
+            with open(chat_template_arg, "r") as filep:
+                template = json.load(filep)
+                try:
+                    sep_style = SeparatorStyle[template["sep_style"]]
+                except KeyError:
+                    raise ValueError(
+                        f"Unknown separator style: {template['sep_style']}"
+                    ) from None
+                register_conv_template(
+                    Conversation(
+                        name=template["name"],
+                        system_template=template["system"] + "\n{system_message}",
+                        system_message=template.get("system_message", ""),
+                        roles=(template["user"], template["assistant"]),
+                        sep_style=sep_style,
+                        sep=template.get("sep", "\n"),
+                        stop_str=template["stop_str"],
+                    ),
+                    override=True,
+                )
+            chat_template_name = template["name"]
     else:
         chat_template_name = chat_template_arg
 
```
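The effect of the new `.jinja` branch above: `--chat-template` can now point at a raw Jinja file, which is normalized and assigned straight to the tokenizer instead of going through the JSON `Conversation` registry. A rough sketch of that normalization, using a stand-in object rather than sglang's real `TokenizerManager` and a made-up template file:

```python
# Stand-in for tokenizer_manager.tokenizer; only the chat_template attribute matters here.
class FakeTokenizer:
    chat_template = None

tokenizer = FakeTokenizer()

# Write a toy .jinja template (hypothetical content, just for the demo).
with open("my_chat_template.jinja", "w") as f:
    f.write("{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\\n{% endfor %}")

# Same steps as the new code path: read the file, strip trailing newlines,
# and turn escaped "\n" sequences into real newlines before assigning it.
with open("my_chat_template.jinja", "r") as f:
    chat_template = "".join(f.readlines()).strip("\n")
tokenizer.chat_template = chat_template.replace("\\n", "\n")

print(tokenizer.chat_template)
```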
```diff
@@ -319,7 +346,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }
 
     except Exception as e:
-        print("error in
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
```
```diff
@@ -357,7 +384,6 @@ async def v1_retrieve_file_content(file_id: str):
 
 
 def v1_generate_request(all_requests):
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
```
```diff
@@ -378,10 +404,13 @@ def v1_generate_request(all_requests):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
```
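`min_tokens`, `stop_token_ids`, and `repetition_penalty` are now forwarded from the OpenAI-compatible request into the engine's sampling parameters. They are SRT-only extensions, so with the official `openai` Python client they have to travel via `extra_body`; a sketch assuming a local sglang server on port 30000 (the model name and token id below are placeholders):

```python
from openai import OpenAI

# Assumes a server already running locally, e.g.:
#   python -m sglang.launch_server --model-path <model> --port 30000
client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

resp = client.completions.create(
    model="default",                 # placeholder; sglang serves one model per server
    prompt="List three prime numbers:",
    max_tokens=64,
    # The new SRT-only fields are not part of the OpenAI schema, so they go in extra_body.
    extra_body={
        "min_tokens": 8,             # forwarded as min_new_tokens
        "repetition_penalty": 1.1,
        "stop_token_ids": [128001],  # placeholder token id
    },
)
print(resp.choices[0].text)
```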
```diff
@@ -485,14 +514,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
```
```diff
@@ -607,20 +640,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=content["meta_info"]["finish_reason"],
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                 )
                 chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     object="text_completion",
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = CompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
```
```diff
@@ -648,7 +695,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
     input_ids = []
     sampling_params_list = []
     image_data_list = []
```
```diff
@@ -691,10 +737,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
             }
```
```diff
@@ -776,14 +825,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
```
```diff
@@ -900,18 +953,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=DeltaMessage(role="assistant"),
-                    finish_reason=content["meta_info"]["finish_reason"],
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                     logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
 
```
```diff
@@ -921,20 +971,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=content["meta_info"]["finish_reason"],
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                     logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
```
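Both streaming endpoints now honor `stream_options.include_usage`: per-chunk `usage` is gone, and a final chunk with empty `choices` and a populated `usage` object is emitted instead, mirroring OpenAI's behavior. A sketch of consuming it with the `openai` client against an assumed local server (model name is a placeholder):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")  # assumed local sglang server

stream = client.chat.completions.create(
    model="default",  # placeholder; the server answers for whatever model it was launched with
    messages=[{"role": "user", "content": "Say hi in one sentence."}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.choices:      # regular delta chunks
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:      # the final usage-only chunk introduced in this release
        print("\ntokens:", chunk.usage.prompt_tokens, "+", chunk.usage.completion_tokens)
```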
```diff
@@ -961,6 +1025,81 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
+def v1_embedding_request(all_requests, tokenizer_manager):
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].input)
+
+    for request in all_requests:
+        prompt = request.input
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = EmbeddingReqInput(
+        **prompt_kwargs,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
+    for idx, ret_item in enumerate(ret):
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
+            )
+        )
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
+
+
+async def v1_embeddings(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [EmbeddingRequest(**request_json)]
+    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
```
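Together with `EmbeddingReqInput` and the new `llama_embedding.py`/`pooler.py` modules in the file list, this adds an OpenAI-compatible embeddings handler. Assuming the server was launched with an embedding-capable model and exposes the handler at `/v1/embeddings` (where the `openai` client sends `embeddings.create`), a client-side sketch looks like:

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")  # assumed local server

resp = client.embeddings.create(
    model="default",  # placeholder model name
    input=["The quick brown fox", "jumped over the lazy dog"],
)
for item in resp.data:
    print(item.index, len(item.embedding))
print("prompt tokens:", resp.usage.prompt_tokens)
```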
sglang/srt/openai_api/protocol.py
CHANGED
```diff
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
```
```diff
@@ -149,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
```
```diff
@@ -157,6 +162,9 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
```
```diff
@@ -188,7 +196,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
```
```diff
@@ -247,12 +255,16 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class ChatMessage(BaseModel):
```
```diff
@@ -294,3 +306,27 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings/create
+    input: Union[List[int], List[List[int]], str, List[str]]
+    model: str
+    encoding_format: str = "float"
+    dimensions: int = None
+    user: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
```
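The new response models are ordinary pydantic models, so the wire format of the embeddings route is easy to inspect locally; a small sketch with made-up values:

```python
from sglang.srt.openai_api.protocol import EmbeddingObject, EmbeddingResponse, UsageInfo

# Build a response by hand to see the JSON shape the /v1/embeddings handler returns.
resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.1, -0.2, 0.3], index=0)],
    model="my-embedding-model",  # placeholder
    usage=UsageInfo(prompt_tokens=4, total_tokens=4),
)
print(resp.model_dump_json(indent=2))
```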
sglang/srt/sampling/penaltylib/__init__.py
ADDED
```diff
@@ -0,0 +1,13 @@
+from .orchestrator import BatchedPenalizerOrchestrator
+from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
+from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
+from .penalizers.presence_penalty import BatchedPresencePenalizer
+from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+
+__all__ = [
+    "BatchedFrequencyPenalizer",
+    "BatchedMinNewTokensPenalizer",
+    "BatchedPresencePenalizer",
+    "BatchedRepetitionPenalizer",
+    "BatchedPenalizerOrchestrator",
+]
```
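The penaltylib package batches frequency, presence, repetition, and min-new-tokens penalties across a running batch. Its orchestrator API is not shown in this diff, so the sketch below is only the standard single-sequence formulation of the three logit penalties involved, not sglang's internal code:

```python
import torch

def apply_penalties(
    logits: torch.Tensor,          # (vocab,) next-token logits
    output_counts: torch.Tensor,   # (vocab,) how often each token was already generated
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
) -> torch.Tensor:
    # Frequency/presence penalties (OpenAI semantics): subtract from the logits.
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * (output_counts > 0)
    # Repetition penalty (CTRL-style): divide positive logits, multiply negative ones,
    # but only for tokens that have already appeared.
    seen = output_counts > 0
    penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty)
    return torch.where(seen, penalized, logits)

logits = torch.tensor([2.0, -1.0, 0.5])
counts = torch.tensor([3.0, 0.0, 1.0])
print(apply_penalties(logits, counts, presence_penalty=0.5, frequency_penalty=0.2, repetition_penalty=1.3))
```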