sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py

@@ -25,7 +25,7 @@ from http import HTTPStatus
 from typing import Dict, List
 
 from fastapi import HTTPException, Request, UploadFile
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import ORJSONResponse, StreamingResponse
 from pydantic import ValidationError
 
 try:
@@ -101,7 +101,7 @@ def create_error_response(
     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
 ):
     error = ErrorResponse(message=message, type=err_type, code=status_code.value)
-    return JSONResponse(content=error.model_dump(), status_code=error.code)
+    return ORJSONResponse(content=error.model_dump(), status_code=error.code)
 
 
 def create_streaming_error_response(
@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if not isinstance(ret, list):
             ret = [ret]
         if end_point == "/v1/chat/completions":
-            responses = v1_chat_generate_response(request, ret, to_file=True)
+            responses = v1_chat_generate_response(
+                request,
+                ret,
+                to_file=True,
+                cache_report=tokenizer_manager.server_args.enable_cache_report,
+            )
         else:
             responses = v1_generate_response(
                 request, ret, tokenizer_manager, to_file=True
@@ -493,23 +498,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params_list.append(
-            {
-                "temperature": request.temperature,
-                "max_new_tokens": request.max_tokens,
-                "min_new_tokens": request.min_tokens,
-                "stop": request.stop,
-                "stop_token_ids": request.stop_token_ids,
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "repetition_penalty": request.repetition_penalty,
-                "regex": request.regex,
-                "json_schema": request.json_schema,
-                "n": request.n,
-                "ignore_eos": request.ignore_eos,
-            }
-        )
+        sampling_params = []
+        if isinstance(request.no_stop_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_stop_trim": (
+                        request.no_stop_trim
+                        if not isinstance(request.no_stop_trim, list)
+                        else request.no_stop_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)
 
     if len(all_requests) == 1:
         prompt = prompts[0]
@@ -601,16 +621,19 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         else:
             logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": (
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -618,10 +641,11 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=(
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             )
 
@@ -751,14 +775,16 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
                 delta = text[len(stream_buffer) :]
                 stream_buffer = stream_buffer + delta
+                finish_reason = content["meta_info"]["finish_reason"]
                 choice_data = CompletionResponseStreamChoice(
                     index=index,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=(
-                        content["meta_info"]["finish_reason"]["type"]
-                        if content["meta_info"]["finish_reason"]
-                        else ""
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                 )
                 chunk = CompletionStreamResponse(
@@ -910,6 +936,7 @@ def v1_chat_generate_request(
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
             "n": request.n,
+            "ignore_eos": request.ignore_eos,
         }
         if request.response_format and request.response_format.type == "json_schema":
             sampling_params["json_schema"] = convert_json_schema_to_str(
@@ -954,7 +981,7 @@ def v1_chat_generate_request(
     return adapted_request, all_requests
 
 
-def v1_chat_generate_response(request, ret, to_file=False):
+def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
     choices = []
 
     for idx, ret_item in enumerate(ret):
@@ -995,16 +1022,19 @@ def v1_chat_generate_response(request, ret, to_file=False):
         else:
             choice_logprobs = None
 
+        finish_reason = ret_item["meta_info"]["finish_reason"]
+
         if to_file:
             # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "message": {"role": "assistant", "content": ret_item["text"]},
                "logprobs": choice_logprobs,
-                "finish_reason": (
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "matched_stop": (
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             }
         else:
@@ -1012,10 +1042,11 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=(
-                    ret_item["meta_info"]["finish_reason"]["type"]
-                    if ret_item["meta_info"]["finish_reason"]
-                    else ""
+                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
                 ),
             )
 
@@ -1051,6 +1082,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
         ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
     )
     completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+    cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
     response = ChatCompletionResponse(
         id=ret[0]["meta_info"]["id"],
         model=request.model,
@@ -1059,6 +1091,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=(
+                {"cached_tokens": cached_tokens} if cache_report else None
+            ),
         ),
     )
     return response
@@ -1134,6 +1169,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 else:
                     choice_logprobs = None
 
+                finish_reason = content["meta_info"]["finish_reason"]
+
                 if is_first:
                     # First chunk with role
                     is_first = False
@@ -1141,9 +1178,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=index,
                         delta=DeltaMessage(role="assistant"),
                         finish_reason=(
-                            content["meta_info"]["finish_reason"]["type"]
-                            if content["meta_info"]["finish_reason"]
-                            else ""
+                            finish_reason["type"] if finish_reason else ""
+                        ),
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
                         ),
                         logprobs=choice_logprobs,
                     )
@@ -1160,10 +1200,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=(
-                        content["meta_info"]["finish_reason"]["type"]
-                        if content["meta_info"]["finish_reason"]
-                        else ""
+                    finish_reason=(finish_reason["type"] if finish_reason else ""),
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
                     ),
                     logprobs=choice_logprobs,
                 )
@@ -1224,7 +1265,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_chat_generate_response(request, ret)
+    response = v1_chat_generate_response(
+        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+    )
     return response
 
 
sglang/srt/openai_api/protocol.py

@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    # only used to return cached tokens when --enable-cache-report is set
+    prompt_tokens_details: Optional[Dict[str, int]] = None
 
 
 class StreamOptions(BaseModel):
@@ -170,10 +172,11 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     json_schema: Optional[str] = None
-    ignore_eos: Optional[bool] = False
-    min_tokens: Optional[int] = 0
+    ignore_eos: bool = False
+    min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_stop_trim: Union[bool, List[bool]] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -181,6 +184,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionResponse(BaseModel):
@@ -197,6 +201,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class CompletionStreamResponse(BaseModel):
@@ -275,6 +280,7 @@ class ChatCompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: bool = False
 
 
 class ChatMessage(BaseModel):
@@ -287,6 +293,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: str
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -308,6 +315,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
+    matched_stop: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
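
Taken together, the adapter and protocol changes above surface new response details: a matched_stop field on every choice, an optional prompt_tokens_details usage block, and the ignore_eos / no_stop_trim request extensions. The snippet below is only a minimal illustration of reading those fields; it assumes a local sglang server launched with something like `python -m sglang.launch_server --model-path <model> --enable-cache-report` and listening on the default port 30000.

import requests

# Illustrative sketch: endpoint, port, and model name are assumptions, the
# response fields (matched_stop, prompt_tokens_details) come from this diff.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Count to three."}],
        "stop": ["\n\n"],
    },
).json()

choice = resp["choices"][0]
print(choice["finish_reason"])                     # e.g. "stop" or "length"
print(choice.get("matched_stop"))                  # matched stop string / token id, else None
print(resp["usage"].get("prompt_tokens_details"))  # {"cached_tokens": N} when cache report is enabled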
sglang/srt/sampling/penaltylib/orchestrator.py

@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
 
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            penalizer.prepare_if_required()
+            pen_is_required = penalizer.prepare_if_required()
+            is_required |= pen_is_required
+        self.is_required = is_required
 
-        self.cumulate_input_tokens(
-            input_ids=[req.origin_input_ids for req in self.reqs()]
-        )
+        if self.is_required:
+            self.cumulate_input_tokens(
+                input_ids=[req.origin_input_ids for req in self.reqs()]
+            )
 
     def reqs(self):
         return self.batch.reqs
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
         Args:
             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
         """
+        if not self.is_required:
+            return
+
         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
 
         for penalizer in self.penalizers.values():
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
         Returns:
             torch.Tensor: The logits after applying the penalizers.
         """
+        if not self.is_required:
+            return
+
         for penalizer in self.penalizers.values():
             logits = penalizer.apply(logits)
 
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
+        if not self.is_required:
+            return
+
         empty_indices = len(indices_to_keep) == 0
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            if not penalizer.is_required() or empty_indices:
+            tmp_is_required = penalizer.is_required()
+            is_required = is_required or tmp_is_required
+            if not tmp_is_required or empty_indices:
                 penalizer.teardown()
             else:
                 # create tensor index only when it's needed
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
                     indices_to_keep=indices_to_keep,
                     indices_tensor_to_keep=indices_tensor_to_keep,
                 )
+        self.is_required = is_required
 
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
         Args:
             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
         """
-        if self.vocab_size != their.vocab_size:
-            raise ValueError(
-                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
-            )
+        if not self.is_required and not their.is_required:
+            return
 
+        self.is_required |= their.is_required
         for Penalizer, their_penalizer in their.penalizers.items():
             if Penalizer not in self.penalizers:
                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
     def prepare_if_required(self):
         if self.is_required():
             self.prepare()
+            return True
+        else:
+            return False
 
     def teardown(self):
         if self.is_prepared():
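
The orchestrator now tracks a single is_required flag so that batches without any penalty-using request skip the per-token bookkeeping entirely. Below is a toy sketch of that gating pattern only; it is not the sglang classes, just an illustration of the idea.

class TinyOrchestrator:
    def __init__(self, penalizers):
        self.penalizers = penalizers
        # Prepare every penalizer, then remember whether any of them is active.
        required = [p.prepare_if_required() for p in penalizers]
        self.is_required = any(required)

    def apply(self, logits):
        if not self.is_required:
            return logits  # fast path: no penalty work for this batch
        for p in self.penalizers:
            logits = p.apply(logits)
        return logits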
sglang/srt/sampling/sampling_batch_info.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor
     min_ps: torch.Tensor
 
+    # All requests use greedy sampling
+    is_all_greedy: bool
+
     # Dispatch in CUDA graph
     need_min_p_sampling: bool
 
@@ -33,30 +36,39 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
 
     # Device
     device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
-        with batch.input_ids.device:
-            temperatures = torch.tensor(
+        device = batch.input_ids.device
+        temperatures = (
+            torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
-            ).view(-1, 1)
-            top_ps = torch.tensor(
-                [r.sampling_params.top_p for r in reqs], dtype=torch.float
-            )
-            top_ks = torch.tensor(
-                [r.sampling_params.top_k for r in reqs], dtype=torch.int
-            )
-            min_ps = torch.tensor(
-                [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
+            .view(-1, 1)
+            .to(device, non_blocking=True)
+        )
+        top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
+        top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+        ).to(device, non_blocking=True)
+        min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
 
         ret = cls(
             temperatures=temperatures,
@@ -64,6 +76,7 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
             device=batch.input_ids.device,
         )
@@ -75,18 +88,21 @@ class SamplingBatchInfo:
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=batch,
-            device=batch.input_ids.device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
+        # handle {filter_batch()} and {merge_batch()} cases as well.
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -97,46 +113,50 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None
 
         for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if not penalizer.is_prepared():
+                continue
+
             if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                if penalizer.is_prepared():
-                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+                self.scaling_penalties = penalizer.cumulated_repetition_penalties
             else:
-                if penalizer.is_prepared():
-                    if self.linear_penalties is None:
-                        bs = self.penalizer_orchestrator.batch.batch_size()
-                        self.linear_penalties = torch.zeros(
-                            (bs, self.vocab_size),
-                            dtype=torch.float32,
-                            device=self.device,
-                        )
-                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+                if self.linear_penalties is None:
+                    bs = self.penalizer_orchestrator.batch.batch_size()
+                    self.linear_penalties = torch.zeros(
+                        (bs, self.vocab_size),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-
-        # Reset the vocab mask
-        self.vocab_mask = None
-
-        if has_regex:
-            self.vocab_mask = torch.zeros(
-                len(self.temperatures),
-                self.vocab_size,
-                dtype=torch.bool,
-                device=self.device,
-            )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
-                    self.vocab_mask[i].fill_(1)
-                    self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                    ] = 0
+        if not has_regex:
+            self.vocab_mask = None
+            return
+
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -175,7 +195,8 @@ class SamplingBatchInfo:
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
@@ -187,6 +208,19 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+
+    def copy(self):
+        return SamplingBatchInfo(
+            temperatures=self.temperatures,
+            top_ps=self.top_ps,
+            top_ks=self.top_ks,
+            min_ps=self.min_ps,
+            is_all_greedy=self.is_all_greedy,
+            need_min_p_sampling=self.need_min_p_sampling,
+            vocab_size=self.vocab_size,
+            device=self.device,
+        )
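
SamplingBatchInfo now records is_all_greedy (true when every request in the batch has top_k <= 1), which lets the sampler fall back to a plain argmax instead of full top-k/top-p sampling. The standalone sketch below illustrates that fast path under simplified assumptions; it is not the sglang sampler itself.

import torch

def sample(logits: torch.Tensor, top_ks: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    # Mirrors the flag above: every request in the batch wants greedy decoding.
    if top_ks.max().item() <= 1:
        return logits.argmax(dim=-1)
    # Otherwise do (simplified) temperature sampling; the real sampler also
    # applies top-k / top-p / min-p filtering and the penalty tensors.
    probs = torch.softmax(logits / temperatures, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)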
sglang/srt/sampling/sampling_params.py

@@ -40,6 +40,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_stop_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_stop_trim = no_stop_trim
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
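
Finally, SamplingParams gains a no_stop_trim flag that the OpenAI layer threads through per request (see the CompletionRequest change above). The name suggests the matched stop sequence is kept in the returned text instead of being trimmed, though that behavior lives in the scheduler rather than in this file. A minimal constructor call, using only keyword arguments visible in the diffs above:

from sglang.srt.sampling.sampling_params import SamplingParams

# Hypothetical usage sketch; values are arbitrary, no_stop_trim is the new 0.3.4 flag.
params = SamplingParams(
    temperature=0.7,
    max_new_tokens=64,
    stop=["\n\n"],
    no_stop_trim=True,
)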