sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
```diff
@@ -18,18 +18,18 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Optional, Union
 
-import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
```
```diff
@@ -98,7 +98,7 @@ class Req:
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
-        self.
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
 
         # Memory info
         self.req_pool_idx = None
```
```diff
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset =
+        self.image_offset = None
         self.pad_value = None
 
         # Prefix info
```
```diff
@@ -142,6 +142,7 @@ class Req:
 
         # Logprobs
         self.return_logprob = False
+        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
```
```diff
@@ -162,6 +163,32 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        if tree_cache is not None:
+            self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                rid=self.rid, key=self.adjust_max_prefix_ids()
+            )
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+
+    def adjust_max_prefix_ids(self):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        max_prefix_len = input_len
+
+        if self.sampling_params.max_new_tokens > 0:
+            # Need at least one token to compute logits
+            max_prefix_len = min(max_prefix_len, input_len - 1)
+
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+
+        if self.normalized_prompt_logprob is None:
+            # Need at least two tokens to compute normalized logprob
+            max_prefix_len = min(max_prefix_len, input_len - 2)
+
+        return self.fill_ids[:max_prefix_len]
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
```
```diff
@@ -176,6 +203,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]
 
```
```diff
@@ -200,16 +229,18 @@ class Req:
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
             return
 
-
-
-
-
-
-
-            )
+        last_token_id = self.output_ids[-1]
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
         if len(self.sampling_params.stop_strs) > 0:
```
```diff
@@ -284,13 +315,12 @@ class ScheduleBatch:
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: RadixCache
+    tree_cache: BasePrefixCache
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
```
```diff
@@ -299,17 +329,11 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None
 
     @classmethod
```
```diff
@@ -373,15 +397,24 @@ class ScheduleBatch:
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
-
-
-
-
-
-
-
-
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=self,
             device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
         )
 
         # Handle logit bias but only allocate when needed
```
```diff
@@ -395,58 +428,40 @@ class ScheduleBatch:
             self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
         bs = self.batch_size()
         reqs = self.reqs
-        input_ids = [r.
-
-
-        # Handle prefix
-        extend_lens = []
-        prefix_lens = []
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
+        # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)
 
-            if
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
+            if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :
-                ] = prefix_indices
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+                    :pre_len
+                ] = req.prefix_indices
 
-
-
-
-
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len
 
         # Set fields
         with torch.device("cuda"):
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.
-
-            self.pixel_values = [r.pixel_values for r in reqs]
-            self.image_sizes = [r.image_size for r in reqs]
-            self.image_offsets = [
-                r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-            ]
-            self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
```
|
|
522
537
|
residual_size = max(0, residual_size)
|
523
538
|
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
524
539
|
|
525
|
-
req.prefix_indices =
|
540
|
+
req.prefix_indices = []
|
526
541
|
req.last_node = None
|
527
542
|
req.extend_input_len = 0
|
528
543
|
|
@@ -596,15 +611,7 @@ class ScheduleBatch:
|
|
596
611
|
req.vid += 1
|
597
612
|
|
598
613
|
# insert the old request into tree_cache
|
599
|
-
self.tree_cache.
|
600
|
-
rid=req.rid,
|
601
|
-
token_ids=cur_all_ids,
|
602
|
-
last_uncached_pos=len(req.prefix_indices),
|
603
|
-
req_pool_idx=req.req_pool_idx,
|
604
|
-
)
|
605
|
-
|
606
|
-
# unlock the last node
|
607
|
-
self.tree_cache.dec_lock_ref(req.last_node)
|
614
|
+
self.tree_cache.cache_finished_req(req, cur_all_ids)
|
608
615
|
|
609
616
|
# re-applying image padding
|
610
617
|
if req.pixel_values is not None:
|
@@ -621,19 +628,21 @@ class ScheduleBatch:
|
|
621
628
|
jump_forward_reqs.append(req)
|
622
629
|
filter_indices.remove(i)
|
623
630
|
|
624
|
-
|
625
|
-
self.filter_batch(filter_indices)
|
631
|
+
self.filter_batch(filter_indices)
|
626
632
|
|
627
633
|
return jump_forward_reqs
|
628
634
|
|
629
635
|
def prepare_for_decode(self, input_ids=None):
|
630
636
|
if input_ids is None:
|
631
637
|
input_ids = [
|
632
|
-
r.output_ids[-1] if r.output_ids else r.
|
638
|
+
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
639
|
+
for r in self.reqs
|
633
640
|
]
|
641
|
+
else:
|
642
|
+
self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
643
|
+
|
634
644
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
635
645
|
self.seq_lens.add_(1)
|
636
|
-
self.prefix_lens = None
|
637
646
|
|
638
647
|
# Alloc mem
|
639
648
|
bs = self.batch_size()
|
@@ -644,23 +653,31 @@ class ScheduleBatch:
|
|
644
653
|
] = self.out_cache_loc
|
645
654
|
|
646
655
|
def filter_batch(self, unfinished_indices: List[int]):
|
656
|
+
if unfinished_indices is None or len(unfinished_indices) == 0:
|
657
|
+
# Filter out all requests
|
658
|
+
self.reqs = []
|
659
|
+
return
|
660
|
+
|
661
|
+
if len(unfinished_indices) == len(self.reqs):
|
662
|
+
# No need to filter
|
663
|
+
return
|
664
|
+
|
647
665
|
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
648
666
|
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
649
667
|
self.seq_lens = self.seq_lens[new_indices]
|
650
668
|
self.input_ids = None
|
651
669
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
652
|
-
self.prefix_lens = None
|
653
670
|
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
654
671
|
self.out_cache_loc = None
|
655
672
|
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
656
673
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
657
674
|
|
675
|
+
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
676
|
+
|
658
677
|
for item in [
|
659
678
|
"temperatures",
|
660
679
|
"top_ps",
|
661
680
|
"top_ks",
|
662
|
-
"frequency_penalties",
|
663
|
-
"presence_penalties",
|
664
681
|
"logit_bias",
|
665
682
|
]:
|
666
683
|
self_val = getattr(self, item, None)
|
@@ -668,13 +685,17 @@ class ScheduleBatch:
|
|
668
685
|
setattr(self, item, self_val[new_indices])
|
669
686
|
|
670
687
|
def merge(self, other: "ScheduleBatch"):
|
688
|
+
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
689
|
+
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
690
|
+
# needs to be called with pre-merged Batch.reqs.
|
691
|
+
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
692
|
+
|
671
693
|
self.reqs.extend(other.reqs)
|
672
694
|
|
673
695
|
self.req_pool_indices = torch.concat(
|
674
696
|
[self.req_pool_indices, other.req_pool_indices]
|
675
697
|
)
|
676
698
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
677
|
-
self.prefix_lens = None
|
678
699
|
self.position_ids_offsets = torch.concat(
|
679
700
|
[self.position_ids_offsets, other.position_ids_offsets]
|
680
701
|
)
|
@@ -686,8 +707,6 @@ class ScheduleBatch:
|
|
686
707
|
"temperatures",
|
687
708
|
"top_ps",
|
688
709
|
"top_ks",
|
689
|
-
"frequency_penalties",
|
690
|
-
"presence_penalties",
|
691
710
|
]:
|
692
711
|
self_val = getattr(self, item, None)
|
693
712
|
other_val = getattr(other, item, None)
|
@@ -711,6 +730,7 @@ class ScheduleBatch:
|
|
711
730
|
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
712
731
|
|
713
732
|
def sample(self, logits: torch.Tensor):
|
733
|
+
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
|
714
734
|
# Post process logits
|
715
735
|
logits = logits.contiguous()
|
716
736
|
logits.div_(self.temperatures)
|
@@ -728,7 +748,8 @@ class ScheduleBatch:
|
|
728
748
|
] = 1
|
729
749
|
logits[i].masked_fill_(~allowed_mask, float("-inf"))
|
730
750
|
|
731
|
-
|
751
|
+
logits = self.penalizer_orchestrator.apply(logits)
|
752
|
+
|
732
753
|
probs = torch.softmax(logits, dim=-1)
|
733
754
|
|
734
755
|
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
@@ -761,6 +782,8 @@ class ScheduleBatch:
|
|
761
782
|
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
762
783
|
)
|
763
784
|
|
785
|
+
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
786
|
+
|
764
787
|
return batch_next_token_ids
|
765
788
|
|
766
789
|
|
@@ -780,7 +803,7 @@ def top_k_top_p_sampling_from_probs_torch(
|
|
780
803
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
781
804
|
except RuntimeError:
|
782
805
|
batch_next_token_ids = torch.zeros(
|
783
|
-
(probs_sort.shape[0],), dtype=torch.
|
806
|
+
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
|
784
807
|
)
|
785
808
|
success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
|
786
809
|
return batch_next_token_ids, success
|