sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -22,10 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-import torch
-
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 
 
 @dataclass
@@ -43,9 +41,9 @@ class GenerateReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob.
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return.
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
@@ -77,7 +75,7 @@ class GenerateReqInput:
             if self.return_logprob is None:
                 self.return_logprob = False
             if self.logprob_start_len is None:
-                self.logprob_start_len = 0
+                self.logprob_start_len = -1
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
@@ -143,7 +141,7 @@
                 self.return_logprob = [self.return_logprob] * num
 
             if self.logprob_start_len is None:
-                self.logprob_start_len = [0] * num
+                self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
 
@@ -155,16 +153,27 @@
 
 @dataclass
 class TokenizedGenerateReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # The pixel values for input images
     pixel_values: List[float]
+    # The hash of input images
     image_hash: int
+    # The image size
     image_size: List[int]
+    # The sampling parameters
     sampling_params: SamplingParams
+    # Whether to return the logprobs
     return_logprob: bool
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: int
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: int
+    # Whether to stream output
     stream: bool
 
 
@@ -215,15 +224,21 @@ class EmbeddingReqInput:
 
 @dataclass
 class TokenizedEmbeddingReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
 
 @dataclass
 class BatchTokenIDOut:
+    # The request id
     rids: List[str]
+    # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
@@ -236,17 +251,25 @@ class BatchTokenIDOut:
 
 @dataclass
 class BatchStrOut:
+    # The request id
     rids: List[str]
+    # The output decoded strings
     output_strs: List[str]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
 @dataclass
 class BatchEmbeddingOut:
+    # The request id
     rids: List[str]
+    # The output embedding
     embeddings: List[List[float]]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
@@ -256,10 +279,20 @@ class FlushCacheReq:
 
 
 @dataclass
-class AbortReq:
-    rid: str
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
 
 
 @dataclass
-class DetokenizeReqInput:
-    input_ids: List[int]
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str
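
The hunk above replaces the old AbortReq/DetokenizeReqInput pair with a weight-reload message type plus a relocated AbortReq. A minimal sketch of exercising the new dataclasses, not taken from the package itself; it assumes sglang 0.2.14 is installed so the import resolves, and the checkpoint path is a placeholder:

# Sketch only: constructing the request/response pair added to io_struct.py above.
from sglang.srt.managers.io_struct import UpdateWeightReqInput, UpdateWeightReqOutput

req = UpdateWeightReqInput(model_path="/path/to/new/checkpoint")  # load_format defaults to None
print(req.model_path, req.load_format)

# A hypothetical success reply the server side might produce.
resp = UpdateWeightReqOutput(success=True, message="weights reloaded")
print(resp.success, resp.message)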

sglang/srt/managers/policy_scheduler.py
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens
 
         self.can_run_list = []
         self.new_inflight_req = None
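
The new mixed_with_decode_tokens argument lets the scheduler reserve room for decode tokens that share the forward pass with a chunked prefill. A small illustration with invented numbers, mirroring only the subtraction logic in the hunk above:

# Illustration only: every remaining prefill budget shrinks by the tokens the
# running decode requests will add to the mixed batch (values are made up).
rem_total_tokens, rem_input_tokens, rem_chunk_tokens = 8192, 4096, 2048
mixed_with_decode_tokens = 64  # e.g. 64 running requests, one new token each

rem_total_tokens -= mixed_with_decode_tokens      # 8128
rem_input_tokens -= mixed_with_decode_tokens      # 4032
if rem_chunk_tokens is not None:
    rem_chunk_tokens -= mixed_with_decode_tokens  # 1984
print(rem_total_tokens, rem_input_tokens, rem_chunk_tokens)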

sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,20 +18,22 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -37,7 +41,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
@@ -235,10 +239,12 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-        if self.tokenizer is None:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        else:
-            matched_eos = last_token_id == self.tokenizer.eos_token_id
+
+        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+
+        if self.tokenizer is not None:
+            matched_eos |= last_token_id == self.tokenizer.eos_token_id
+
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
@@ -266,7 +272,7 @@
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            warnings.warn(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -325,17 +331,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunekd prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -383,51 +385,7 @@
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if self.logit_bias is None:
-                    self.logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+    def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -465,8 +423,32 @@
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
 
-        self.batch_sampling_params(vocab_size, int_token_logit_bias)
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
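
mix_with_running above folds a running decode batch into an extend batch: each decode request contributes exactly one new token, so everything before that token counts as prefix. A standalone sketch of just that prefix-length bookkeeping, using hypothetical stand-in objects rather than real Req instances:

# Sketch only: the prefix_lens_cpu computation from mix_with_running, applied to
# plain dataclasses (names and values below are invented for illustration).
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeReq:
    origin_input_ids: List[int]
    output_ids: List[int]
    prefix_indices: List[int] = field(default_factory=list)

prefill_reqs = [FakeReq(origin_input_ids=[1, 2, 3, 4], output_ids=[], prefix_indices=[10, 11])]
decode_reqs = [FakeReq(origin_input_ids=[5, 6, 7], output_ids=[8, 9])]

# New prefill requests: the prefix is whatever was already cached.
prefix_lens_cpu = [len(r.prefix_indices) for r in prefill_reqs]                            # [2]
# Running decode requests: everything except the single token being extended is prefix.
prefix_lens_cpu += [len(r.origin_input_ids) + len(r.output_ids) - 1 for r in decode_reqs]  # [2, 4]
print(prefix_lens_cpu)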
@@ -639,7 +621,7 @@ class ScheduleBatch:
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
@@ -652,6 +634,8 @@
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -672,23 +656,13 @@
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
@@ -703,111 +677,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
             batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
+                sample_output.success, batch_next_token_ids, argmax_ids
             )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
-
-        return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+        return sample_output.batch_next_token_ids
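
The old in-place sampling fallback survives above as check_sample_results: rows where top-k/top-p sampling reported failure fall back to a greedy argmax over the NaN-cleaned probabilities. A self-contained sketch of just that recovery step, using hand-made tensors instead of a real SampleOutput object:

# Sketch only: reproducing the argmax fallback shown in check_sample_results.
import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [float("nan"), float("nan"), float("nan")]])
batch_next_token_ids = torch.tensor([0, -1])
success = torch.tensor([True, False])

if not torch.all(success):
    probs = probs.masked_fill(torch.isnan(probs), 0.0)  # NaNs -> 0 so argmax is defined
    argmax_ids = torch.argmax(probs, dim=-1)            # greedy (top_k=1) choice per row
    batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

print(batch_next_token_ids)  # tensor([0, 0]); failed row replaced by its argmax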