sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py CHANGED
@@ -22,7 +22,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -241,7 +240,7 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/torch_native_llama.py CHANGED
@@ -24,7 +24,6 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -380,7 +379,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/xverse.py CHANGED
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -297,7 +296,7 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         efficient_weight_load=False,
     ) -> None:
         super().__init__()
sglang/srt/models/xverse_moe.py CHANGED
@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -183,7 +182,7 @@ class XverseAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -260,7 +259,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -328,7 +327,7 @@ class XverseModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -371,7 +370,7 @@ class XverseMoeForCausalLM(nn.Module):
     def __init__(
        self,
        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
        super().__init__()
sglang/srt/models/yivl.py CHANGED
@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -32,7 +31,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__(config, quant_config, cache_config)
 
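All of the model-file hunks above make the same change: the constructor keeps accepting a cache_config argument for call compatibility, but the Optional[CacheConfig] annotation is dropped so the from vllm.config import CacheConfig import can be removed. A minimal sketch of the resulting pattern, using an illustrative toy class rather than an actual sglang model:

    from typing import Optional


    class QuantizationConfig:
        """Placeholder standing in for sglang's QuantizationConfig."""


    class ToyForCausalLM:
        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            cache_config=None,  # untyped on purpose: no vllm.config.CacheConfig import required
        ) -> None:
            self.config = config
            self.quant_config = quant_config
            # The argument is still accepted so loaders that pass a cache config keep working.
            self.cache_config = cache_config


    model = ToyForCausalLM(config={"hidden_size": 16})
    print(model.cache_config)  # None

The same one-line signature change accounts for the +1/-2 counts on most of the model files in the list above.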
sglang/srt/openai_api/adapter.py CHANGED
@@ -25,7 +25,7 @@ from http import HTTPStatus
 from typing import Dict, List
 
 from fastapi import HTTPException, Request, UploadFile
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import ORJSONResponse, StreamingResponse
 from pydantic import ValidationError
 
 try:
@@ -101,7 +101,7 @@ def create_error_response(
     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
 ):
     error = ErrorResponse(message=message, type=err_type, code=status_code.value)
-    return JSONResponse(content=error.model_dump(), status_code=error.code)
+    return ORJSONResponse(content=error.model_dump(), status_code=error.code)
 
 
 def create_streaming_error_response(
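ORJSONResponse is FastAPI's orjson-backed response class, so the switch above keeps create_error_response behaviourally identical while encoding the error payload with orjson (which must be installed). A self-contained sketch of the same pattern, with a simplified ErrorResponse model standing in for sglang's:

    from http import HTTPStatus

    from fastapi import FastAPI
    from fastapi.responses import ORJSONResponse  # requires the orjson package
    from pydantic import BaseModel

    app = FastAPI()


    class ErrorResponse(BaseModel):
        # Simplified stand-in for sglang's ErrorResponse model.
        message: str
        type: str
        code: int


    def create_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ):
        error = ErrorResponse(message=message, type=err_type, code=status_code.value)
        # Same shape as JSONResponse, but the body is serialized by orjson.
        return ORJSONResponse(content=error.model_dump(), status_code=error.code)


    @app.get("/boom")
    def boom():
        return create_error_response("something went wrong")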
@@ -117,7 +117,9 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-    logger.info(f"Use chat template: {chat_template_arg}")
+    logger.info(
+        f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
+    )
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -300,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if not isinstance(ret, list):
             ret = [ret]
         if end_point == "/v1/chat/completions":
-            responses = v1_chat_generate_response(request, ret, to_file=True)
+            responses = v1_chat_generate_response(
+                request,
+                ret,
+                to_file=True,
+                cache_report=tokenizer_manager.server_args.enable_cache_report,
+            )
         else:
             responses = v1_generate_response(
                 request, ret, tokenizer_manager, to_file=True
@@ -491,23 +498,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params_list.append(
-            {
-                "temperature": request.temperature,
-                "max_new_tokens": request.max_tokens,
-                "min_new_tokens": request.min_tokens,
-                "stop": request.stop,
-                "stop_token_ids": request.stop_token_ids,
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "repetition_penalty": request.repetition_penalty,
-                "regex": request.regex,
-                "json_schema": request.json_schema,
-                "n": request.n,
-                "ignore_eos": request.ignore_eos,
-            }
-        )
+        sampling_params = []
+        if isinstance(request.no_stop_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_stop_trim": (
+                        request.no_stop_trim
+                        if not isinstance(request.no_stop_trim, list)
+                        else request.no_stop_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)
 
     if len(all_requests) == 1:
         prompt = prompts[0]
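The new no_stop_trim handling lets /v1/completions take either a single boolean or one boolean per prompt; when a list is supplied, each prompt gets its own sampling-params dict as shown above. A hedged client-side sketch (the server address assumes a locally running sglang instance on its default port, and the prompts are made up):

    import json
    import urllib.request

    # Two prompts: keep the matched stop string in the first completion, trim it in the second.
    payload = {
        "model": "default",
        "prompt": ["List three colors:", "List three animals:"],
        "max_tokens": 32,
        "stop": ["\n\n"],
        "no_stop_trim": [True, False],  # one flag per prompt, matching the list branch above
    }

    req = urllib.request.Request(
        "http://localhost:30000/v1/completions",  # assumed local server address
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        for choice in json.load(resp)["choices"]:
            print(choice["text"])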
@@ -599,16 +621,19 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         else:
             logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": (
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -616,10 +641,11 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=(
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             )
 
@@ -749,14 +775,16 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
                 delta = text[len(stream_buffer) :]
                 stream_buffer = stream_buffer + delta
+                finish_reason = content["meta_info"]["finish_reason"]
                 choice_data = CompletionResponseStreamChoice(
                     index=index,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=(
-                        content["meta_info"]["finish_reason"]["type"]
-                        if content["meta_info"]["finish_reason"]
-                        else ""
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                 )
                 chunk = CompletionStreamResponse(
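With matched_stop now surfaced on both streaming and non-streaming choices, a client can tell which stop string or stop token id actually ended generation instead of only seeing finish_reason == "stop". A hedged sketch of consuming the field from a parsed response body (the values below are invented for illustration):

    # A parsed /v1/completions response body; the values are fabricated for illustration.
    response = {
        "choices": [
            {
                "index": 0,
                "text": "red, green, blue",
                "finish_reason": "stop",
                "matched_stop": "\n\n",  # str for a stop string, int for a stop token id
            }
        ]
    }

    for choice in response["choices"]:
        if choice["finish_reason"] == "stop" and choice.get("matched_stop") is not None:
            print(f"generation ended on stop sequence {choice['matched_stop']!r}")
        elif choice["finish_reason"] == "length":
            print("generation hit max_tokens")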
@@ -908,6 +936,7 @@ def v1_chat_generate_request(
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
             "n": request.n,
+            "ignore_eos": request.ignore_eos,
         }
         if request.response_format and request.response_format.type == "json_schema":
             sampling_params["json_schema"] = convert_json_schema_to_str(
@@ -924,7 +953,7 @@
         else:
             prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
-        image_data = image_data_list[0]
+        image_data_list = image_data_list[0]
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
@@ -937,7 +966,7 @@
 
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
-        image_data=image_data,
+        image_data=image_data_list,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         logprob_start_len=logprob_start_lens,
@@ -952,7 +981,7 @@
     return adapted_request, all_requests
 
 
-def v1_chat_generate_response(request, ret, to_file=False):
+def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
     choices = []
 
     for idx, ret_item in enumerate(ret):
@@ -993,16 +1022,19 @@ def v1_chat_generate_response(request, ret, to_file=False):
         else:
             choice_logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": (
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -1010,10 +1042,11 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=(
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             )
 
@@ -1049,6 +1082,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
         ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
     )
     completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+    cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
     response = ChatCompletionResponse(
         id=ret[0]["meta_info"]["id"],
         model=request.model,
@@ -1057,6 +1091,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=(
+                {"cached_tokens": cached_tokens} if cache_report else None
+            ),
         ),
     )
     return response
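When the server is launched with --enable-cache-report, the chat usage block gains a prompt_tokens_details entry carrying the number of prefix-cached prompt tokens, as wired up above. A hedged sketch of reading it on the client side (the numbers are invented):

    # Usage block from a parsed /v1/chat/completions response; numbers are invented.
    usage = {
        "prompt_tokens": 128,
        "completion_tokens": 40,
        "total_tokens": 168,
        "prompt_tokens_details": {"cached_tokens": 96},  # only present with --enable-cache-report
    }

    details = usage.get("prompt_tokens_details") or {}
    cached = details.get("cached_tokens", 0)
    hit_rate = cached / usage["prompt_tokens"] if usage["prompt_tokens"] else 0.0
    print(f"prefix cache covered {cached}/{usage['prompt_tokens']} prompt tokens ({hit_rate:.0%})")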
@@ -1132,6 +1169,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 else:
                     choice_logprobs = None
 
+                finish_reason = content["meta_info"]["finish_reason"]
+
                 if is_first:
                     # First chunk with role
                     is_first = False
@@ -1139,9 +1178,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=index,
                         delta=DeltaMessage(role="assistant"),
                         finish_reason=(
-                            content["meta_info"]["finish_reason"]["type"]
-                            if content["meta_info"]["finish_reason"]
-                            else ""
+                            finish_reason["type"] if finish_reason else ""
+                        ),
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
                         ),
                         logprobs=choice_logprobs,
                     )
@@ -1158,10 +1200,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=(
-                        content["meta_info"]["finish_reason"]["type"]
-                        if content["meta_info"]["finish_reason"]
-                        else ""
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                     logprobs=choice_logprobs,
                 )
@@ -1222,7 +1265,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_chat_generate_response(request, ret)
+    response = v1_chat_generate_response(
+        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+    )
 
     return response
 
sglang/srt/openai_api/protocol.py CHANGED
@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    # only used to return cached tokens when --enable-cache-report is set
+    prompt_tokens_details: Optional[Dict[str, int]] = None
 
 
 class StreamOptions(BaseModel):
@@ -170,10 +172,11 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     json_schema: Optional[str] = None
-    ignore_eos: Optional[bool] = False
-    min_tokens: Optional[int] = 0
+    ignore_eos: bool = False
+    min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_stop_trim: Union[bool, List[bool]] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -181,6 +184,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionResponse(BaseModel):
@@ -197,6 +201,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionStreamResponse(BaseModel):
@@ -275,6 +280,7 @@ class ChatCompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: bool = False
 
 
 class ChatMessage(BaseModel):
@@ -287,6 +293,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: str
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -308,6 +315,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
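Taken together, the protocol changes add matched_stop to every choice model, drop the Optional wrapper from ignore_eos and min_tokens on CompletionRequest, and introduce no_stop_trim there plus ignore_eos on ChatCompletionRequest. A minimal pydantic sketch limited to the request fields touched in this diff (it is deliberately not the full sglang CompletionRequest):

    from typing import List, Optional, Union

    from pydantic import BaseModel, Field


    class CompletionRequestSubset(BaseModel):
        # Only the fields touched by this diff; the real CompletionRequest has many more.
        prompt: Union[str, List[str]]
        ignore_eos: bool = False
        min_tokens: int = 0
        repetition_penalty: Optional[float] = 1.0
        stop_token_ids: Optional[List[int]] = Field(default_factory=list)
        no_stop_trim: Union[bool, List[bool]] = False


    req = CompletionRequestSubset(prompt=["a", "b"], no_stop_trim=[True, False])
    print(req.no_stop_trim)  # [True, False]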
sglang/srt/sampling/penaltylib/orchestrator.py CHANGED
@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
 
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            penalizer.prepare_if_required()
+            pen_is_required = penalizer.prepare_if_required()
+            is_required |= pen_is_required
+        self.is_required = is_required
 
-        self.cumulate_input_tokens(
-            input_ids=[req.origin_input_ids for req in self.reqs()]
-        )
+        if self.is_required:
+            self.cumulate_input_tokens(
+                input_ids=[req.origin_input_ids for req in self.reqs()]
+            )
 
     def reqs(self):
         return self.batch.reqs
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
         Args:
             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
         """
+        if not self.is_required:
+            return
+
         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
 
         for penalizer in self.penalizers.values():
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
         Returns:
             torch.Tensor: The logits after applying the penalizers.
         """
+        if not self.is_required:
+            return
+
         for penalizer in self.penalizers.values():
             logits = penalizer.apply(logits)
 
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
+        if not self.is_required:
+            return
+
         empty_indices = len(indices_to_keep) == 0
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            if not penalizer.is_required() or empty_indices:
+            tmp_is_required = penalizer.is_required()
+            is_required = is_required or tmp_is_required
+            if not tmp_is_required or empty_indices:
                 penalizer.teardown()
             else:
                 # create tensor index only when it's needed
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
                     indices_to_keep=indices_to_keep,
                     indices_tensor_to_keep=indices_tensor_to_keep,
                 )
+        self.is_required = is_required
 
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
         Args:
             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
         """
-        if self.vocab_size != their.vocab_size:
-            raise ValueError(
-                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
-            )
+        if not self.is_required and not their.is_required:
+            return
 
+        self.is_required |= their.is_required
         for Penalizer, their_penalizer in their.penalizers.items():
             if Penalizer not in self.penalizers:
                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
     def prepare_if_required(self):
         if self.is_required():
             self.prepare()
+            return True
+        else:
+            return False
 
     def teardown(self):
         if self.is_prepared():
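The orchestrator changes are an optimization: prepare_if_required() now reports whether each penalizer is actually needed, and the aggregated is_required flag lets cumulate_output_tokens, apply, filter, and merge return immediately when no request in the batch uses any penalty. A stripped-down sketch of the same short-circuit pattern with toy stand-ins (not sglang's classes):

    class ToyPenalizer:
        """Toy stand-in for a _BatchedPenalizer subclass."""

        def __init__(self, active: bool):
            self.active = active
            self.prepared = False

        def is_required(self) -> bool:
            return self.active

        def prepare_if_required(self) -> bool:
            if self.is_required():
                self.prepared = True
                return True
            return False

        def apply(self, logits):
            # A real penalizer would modify the logits tensor here.
            return logits


    class ToyOrchestrator:
        def __init__(self, penalizers):
            self.penalizers = penalizers
            # Prepare every penalizer, then record whether any of them is active.
            flags = [p.prepare_if_required() for p in penalizers]
            self.is_required = any(flags)

        def apply(self, logits):
            if not self.is_required:
                return logits  # fast path: nothing to penalize in this batch
            for p in self.penalizers:
                logits = p.apply(logits)
            return logits


    orch = ToyOrchestrator([ToyPenalizer(active=False), ToyPenalizer(active=False)])
    print(orch.is_required)        # False, so apply() is a no-op
    print(orch.apply([0.1, 0.2]))  # logits returned unchanged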