sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
@@ -28,9 +28,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -46,9 +44,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
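The four hunks above are from sglang/srt/models/qwen2_moe.py (the "+2 -13" entry in the file list): SiluAndMul and RMSNorm are now imported from sglang's own layers instead of vLLM's, and the vLLM Sampler/sample() hook is removed from the model. For orientation, a minimal sketch of the semantics these two layers are expected to compute, assuming the standard unfused definitions (the shipped modules wrap fused kernels but should match these outputs numerically):

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then scale.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # Input has shape [..., 2 * d]: SiLU-gate the first half, multiply by the second half.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]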
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,9 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
     FileDeleteResponse,
     FileRequest,
@@ -74,7 +77,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in SGlang backend
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}


@@ -82,6 +85,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None


+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
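The new format_finish_reason helper translates sglang's internal finish-reason strings into the values OpenAI clients expect; only the prefixes matter, so the exact internal strings below are illustrative:

    from sglang.srt.openai_api.adapter import format_finish_reason

    assert format_finish_reason("None") is None                  # still generating
    assert format_finish_reason("FINISH_MATCHED_STR") == "stop"  # matched a stop string/token
    assert format_finish_reason("FINISH_LENGTH") == "length"     # hit max_new_tokens
    assert format_finish_reason("FINISH_ABORT") == "abort"       # request was aborted
    assert format_finish_reason("something_else") == "unknown"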
@@ -101,7 +117,7 @@ def create_streaming_error_response(
     return json_str


-def load_chat_template_for_openai_api(chat_template_arg):
+def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name

     print(f"Use chat template: {chat_template_arg}")
@@ -111,27 +127,38 @@ def load_chat_template_for_openai_api(chat_template_arg):
                 f"Chat template {chat_template_arg} is not a built-in template name "
                 "or a valid chat template file path."
             )
-        with open(chat_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                sep_style = SeparatorStyle[template["sep_style"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown separator style: {template['sep_style']}"
-                ) from None
-            register_conv_template(
-                Conversation(
-                    name=template["name"],
-                    system_template=template["system"] + "\n{system_message}",
-                    system_message=template.get("system_message", ""),
-                    roles=(template["user"], template["assistant"]),
-                    sep_style=sep_style,
-                    sep=template.get("sep", "\n"),
-                    stop_str=template["stop_str"],
-                ),
-                override=True,
+        if chat_template_arg.endswith(".jinja"):
+            with open(chat_template_arg, "r") as f:
+                chat_template = "".join(f.readlines()).strip("\n")
+            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
+                "\\n", "\n"
             )
-        chat_template_name = template["name"]
+            chat_template_name = None
+        else:
+            assert chat_template_arg.endswith(
+                ".json"
+            ), "unrecognized format of chat template file"
+            with open(chat_template_arg, "r") as filep:
+                template = json.load(filep)
+                try:
+                    sep_style = SeparatorStyle[template["sep_style"]]
+                except KeyError:
+                    raise ValueError(
+                        f"Unknown separator style: {template['sep_style']}"
+                    ) from None
+                register_conv_template(
+                    Conversation(
+                        name=template["name"],
+                        system_template=template["system"] + "\n{system_message}",
+                        system_message=template.get("system_message", ""),
+                        roles=(template["user"], template["assistant"]),
+                        sep_style=sep_style,
+                        sep=template.get("sep", "\n"),
+                        stop_str=template["stop_str"],
+                    ),
+                    override=True,
+                )
+                chat_template_name = template["name"]
     else:
         chat_template_name = chat_template_arg

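load_chat_template_for_openai_api now takes the tokenizer manager and accepts either a .jinja file, which is assigned directly to tokenizer.chat_template, or the existing .json conversation-template format. A hypothetical .json template using exactly the keys the loader reads (all values are illustrative; sep_style must name a member of sglang's SeparatorStyle enum, ADD_NEW_LINE_SINGLE is assumed here):

    import json

    example_template = {
        "name": "my-template",
        "system": "<|system|>",                # becomes the system_template prefix
        "system_message": "You are a helpful assistant.",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
        "sep_style": "ADD_NEW_LINE_SINGLE",    # assumed member name, adjust to your template
        "sep": "\n",
        "stop_str": "<|user|>",
    }

    with open("my_template.json", "w") as f:
        json.dump(example_template, f)

    # The file can then be passed via the server's chat-template argument,
    # or a .jinja file can be used instead to set the tokenizer template directly.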
@@ -319,7 +346,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }

     except Exception as e:
-        print("error in SGlang:", e)
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -357,7 +384,6 @@ async def v1_retrieve_file_content(file_id: str):


 def v1_generate_request(all_requests):
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -378,10 +404,13 @@ def v1_generate_request(all_requests):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
@@ -485,14 +514,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
@@ -607,20 +640,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=content["meta_info"]["finish_reason"],
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                 )
                 chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     object="text_completion",
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = CompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
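Per-chunk usage payloads are removed from the stream; instead, when the request sets stream_options.include_usage (mirroring the OpenAI API), a final chunk with an empty choices list and a populated usage object is emitted after the content chunks. A hypothetical streaming client against a local server (port 30000 is sglang's default; the model name is illustrative):

    import json
    import requests

    with requests.post(
        "http://localhost:30000/v1/completions",
        json={
            "model": "default",
            "prompt": "Hello",
            "max_tokens": 16,
            "stream": True,
            "stream_options": {"include_usage": True},
        },
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            chunk = json.loads(payload)
            if chunk.get("usage"):            # final usage-only chunk
                print("\nusage:", chunk["usage"])
            elif chunk["choices"]:
                print(chunk["choices"][0]["text"], end="")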
@@ -648,7 +695,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):


 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
     input_ids = []
     sampling_params_list = []
     image_data_list = []
@@ -691,10 +737,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
             }
@@ -776,14 +825,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
@@ -900,18 +953,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(role="assistant"),
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"

@@ -921,20 +971,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=content["meta_info"]["finish_reason"],
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                     logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             error = create_streaming_error_response(str(e))
             yield f"data: {error}\n\n"
@@ -961,6 +1025,81 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     return response


+def v1_embedding_request(all_requests, tokenizer_manager):
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].input)
+
+    for request in all_requests:
+        prompt = request.input
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = EmbeddingReqInput(
+        **prompt_kwargs,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
+    for idx, ret_item in enumerate(ret):
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
+            )
+        )
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
+
+
+async def v1_embeddings(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [EmbeddingRequest(**request_json)]
+    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
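Together with the EmbeddingReqInput import and the new protocol models, these functions back a /v1/embeddings endpoint: v1_embedding_request converts the OpenAI-style payload into an EmbeddingReqInput, and v1_embedding_response wraps the results in the OpenAI embeddings schema. A hypothetical client call, assuming the server was launched with an embedding model (see the new llama_embedding.py in the file list):

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/embeddings",   # default sglang port; model name illustrative
        json={"model": "default", "input": "The quick brown fox"},
    )
    body = resp.json()
    vector = body["data"][0]["embedding"]         # list of floats
    print(len(vector), body["usage"]["prompt_tokens"])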
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0


+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -149,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -157,6 +162,9 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)


 class CompletionResponseChoice(BaseModel):
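The three new CompletionRequest fields above are the SRT-only knobs that v1_generate_request now forwards to the engine (min_tokens as min_new_tokens, plus stop_token_ids and repetition_penalty). A hypothetical request exercising them (server address and values are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/completions",
        json={
            "model": "default",
            "prompt": "List three prime numbers:",
            "max_tokens": 64,
            "min_tokens": 8,             # forwarded as min_new_tokens
            "repetition_penalty": 1.1,   # multiplicative penalty on repeated tokens
            "stop_token_ids": [2],       # stop on specific token ids (value illustrative)
        },
    )
    print(resp.json()["choices"][0]["text"])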
@@ -188,7 +196,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None


 class ChatCompletionMessageGenericParam(BaseModel):
@@ -247,12 +255,16 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)


 class ChatMessage(BaseModel):
@@ -294,3 +306,27 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings/create
+    input: Union[List[int], List[List[int]], str, List[str]]
+    model: str
+    encoding_format: str = "float"
+    dimensions: int = None
+    user: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
@@ -0,0 +1,13 @@
+from .orchestrator import BatchedPenalizerOrchestrator
+from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
+from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
+from .penalizers.presence_penalty import BatchedPresencePenalizer
+from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+
+__all__ = [
+    "BatchedFrequencyPenalizer",
+    "BatchedMinNewTokensPenalizer",
+    "BatchedPresencePenalizer",
+    "BatchedRepetitionPenalizer",
+    "BatchedPenalizerOrchestrator",
+]
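This new sglang/srt/sampling/penaltylib/__init__.py exposes the batched penalizers added in this release (frequency, presence, repetition, and min-new-tokens) behind a single orchestrator. As a rough sketch of the logit adjustments such penalizers conventionally apply, not the package's actual classes (the real implementations work on whole batches and track token counts incrementally):

    import torch

    def apply_penalties(
        logits: torch.Tensor,         # [vocab_size] logits for the next token
        token_counts: torch.Tensor,   # [vocab_size] occurrences of each id seen so far
        frequency_penalty: float,
        presence_penalty: float,
        repetition_penalty: float,
    ) -> torch.Tensor:
        # OpenAI-style additive penalties.
        logits = logits - frequency_penalty * token_counts
        logits = logits - presence_penalty * (token_counts > 0).float()
        # CTRL-style multiplicative repetition penalty on previously seen tokens.
        seen = token_counts > 0
        logits = torch.where(seen & (logits > 0), logits / repetition_penalty, logits)
        logits = torch.where(seen & (logits <= 0), logits * repetition_penalty, logits)
        return logits

The min-new-tokens penalizer, by contrast, presumably suppresses stop/EOS tokens until the requested minimum number of new tokens has been generated.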