sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
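The version bump itself is the one-line change in sglang/version.py. As a quick sanity check after upgrading, the installed version can be read back; this is a minimal sketch and assumes the package re-exports the version string defined in sglang/version.py as `sglang.__version__`:

```python
# Hypothetical check after `pip install --upgrade sglang==0.2.12`.
# Assumes sglang exposes the string from sglang/version.py as __version__.
import sglang

print(sglang.__version__)  # expected: "0.2.12"
```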
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
sglang/srt/openai_api/adapter.py
CHANGED
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,9 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
     FileDeleteResponse,
     FileRequest,
@@ -74,7 +77,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}


@@ -82,6 +85,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None


+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -319,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }

     except Exception as e:
-        print("error in
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -357,7 +373,6 @@ async def v1_retrieve_file_content(file_id: str):


 def v1_generate_request(all_requests):
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -378,10 +393,13 @@ def v1_generate_request(all_requests):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
@@ -485,14 +503,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason":
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
@@ -607,20 +629,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                 )
                 chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     object="text_completion",
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = CompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
@@ -648,7 +684,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):


 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
     input_ids = []
     sampling_params_list = []
     image_data_list = []
@@ -691,10 +726,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
             }
@@ -776,14 +814,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason":
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
@@ -900,18 +942,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(role="assistant"),
-                        finish_reason=
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"

@@ -921,20 +960,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                     logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
@@ -961,6 +1014,81 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     return response


+def v1_embedding_request(all_requests, tokenizer_manager):
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].input)
+
+    for request in all_requests:
+        prompt = request.input
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = EmbeddingReqInput(
+        **prompt_kwargs,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
+    for idx, ret_item in enumerate(ret):
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
+            )
+        )
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
+
+
+async def v1_embeddings(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [EmbeddingRequest(**request_json)]
+    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
sglang/srt/openai_api/protocol.py
CHANGED
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0


+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -149,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -157,6 +162,9 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)


 class CompletionResponseChoice(BaseModel):
@@ -188,7 +196,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None


 class ChatCompletionMessageGenericParam(BaseModel):
@@ -247,12 +255,16 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)


 class ChatMessage(BaseModel):
@@ -294,3 +306,27 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings/create
+    input: Union[List[int], List[List[int]], str, List[str]]
+    model: str
+    encoding_format: str = "float"
+    dimensions: int = None
+    user: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
sglang/srt/sampling/penaltylib/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from .orchestrator import BatchedPenalizerOrchestrator
+from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
+from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
+from .penalizers.presence_penalty import BatchedPresencePenalizer
+from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+
+__all__ = [
+    "BatchedFrequencyPenalizer",
+    "BatchedMinNewTokensPenalizer",
+    "BatchedPresencePenalizer",
+    "BatchedRepetitionPenalizer",
+    "BatchedPenalizerOrchestrator",
+]