sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -18,18 +18,18 @@ limitations under the License.
  import logging
  import warnings
  from dataclasses import dataclass
- from typing import List, Union
+ from typing import List, Optional, Union
 
- import numpy as np
  import torch
  from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+ import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.global_config import global_config
  from sglang.srt.constrained import RegexGuide
  from sglang.srt.constrained.jump_forward import JumpForwardMap
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
- from sglang.srt.mem_cache.radix_cache import RadixCache
 
  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -98,7 +98,7 @@ class Req:
          self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
          self.origin_input_ids = origin_input_ids
          self.output_ids = []  # Each decode stage's output ids
-         self.input_ids = None  # input_ids = origin_input_ids + output_ids
+         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
 
          # Memory info
          self.req_pool_idx = None
@@ -124,7 +124,7 @@ class Req:
          # For vision input
          self.pixel_values = None
          self.image_size = None
-         self.image_offset = 0
+         self.image_offset = None
          self.pad_value = None
 
          # Prefix info
@@ -142,6 +142,7 @@ class Req:
 
          # Logprobs
          self.return_logprob = False
+         self.embedding = None
          self.logprob_start_len = 0
          self.top_logprobs_num = 0
          self.normalized_prompt_logprob = None
@@ -162,6 +163,32 @@ class Req:
      def finished(self) -> bool:
          return self.finished_reason is not None
 
+     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+         self.fill_ids = self.origin_input_ids + self.output_ids
+         if tree_cache is not None:
+             self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                 rid=self.rid, key=self.adjust_max_prefix_ids()
+             )
+         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+ 
+     def adjust_max_prefix_ids(self):
+         self.fill_ids = self.origin_input_ids + self.output_ids
+         input_len = len(self.fill_ids)
+         max_prefix_len = input_len
+ 
+         if self.sampling_params.max_new_tokens > 0:
+             # Need at least one token to compute logits
+             max_prefix_len = min(max_prefix_len, input_len - 1)
+ 
+         if self.return_logprob:
+             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+ 
+         if self.normalized_prompt_logprob is None:
+             # Need at least two tokens to compute normalized logprob
+             max_prefix_len = min(max_prefix_len, input_len - 2)
+ 
+         return self.fill_ids[:max_prefix_len]
+ 
      # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
      def init_incremental_detokenize(self):
          first_iter = self.surr_offset is None or self.read_offset is None
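To make the clamping in adjust_max_prefix_ids concrete, the following walkthrough uses hypothetical numbers (10 prompt tokens, 2 already-generated tokens); it mirrors the logic above but is not code from the package:

    # Illustrative only: hypothetical request state, not taken from sglang.
    fill_ids = list(range(12))            # 10 prompt tokens + 2 generated tokens
    input_len = len(fill_ids)             # 12
    max_prefix_len = input_len

    max_new_tokens = 16
    if max_new_tokens > 0:
        # Keep at least one uncached token so the forward pass still produces logits.
        max_prefix_len = min(max_prefix_len, input_len - 1)   # -> 11

    return_logprob = False
    logprob_start_len = 0
    if return_logprob:
        max_prefix_len = min(max_prefix_len, logprob_start_len)

    normalized_prompt_logprob = None
    if normalized_prompt_logprob is None:
        # Keep at least two uncached tokens so the normalized prompt logprob can be computed.
        max_prefix_len = min(max_prefix_len, input_len - 2)   # -> 10

    matchable_prefix = fill_ids[:max_prefix_len]              # 10 tokens offered to the prefix cache

In this case init_next_round_input hands at most ten of the twelve fill_ids to tree_cache.match_prefix, so the extend step always has tokens left to run through the model.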
@@ -176,6 +203,8 @@ class Req:
          return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
      def get_next_inc_detokenization(self):
+         if self.tokenizer is None:
+             return False, ""
          read_ids, read_offset = self.init_incremental_detokenize()
          surr_ids = read_ids[:read_offset]
 
@@ -200,16 +229,20 @@ class Req:
              return
 
          if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-             self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+             self.finished_reason = FINISH_LENGTH(
+                 length=self.sampling_params.max_new_tokens
+             )
              return
 
-         if (
-             self.output_ids[-1] == self.tokenizer.eos_token_id
-             and not self.sampling_params.ignore_eos
-         ):
-             self.finished_reason = FINISH_MATCHED_TOKEN(
-                 matched=self.tokenizer.eos_token_id
-             )
+         last_token_id = self.output_ids[-1]
+ 
+         matched_eos = last_token_id in self.sampling_params.stop_token_ids
+ 
+         if self.tokenizer is not None:
+             matched_eos |= last_token_id == self.tokenizer.eos_token_id
+ 
+         if matched_eos and not self.sampling_params.ignore_eos:
+             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
              return
 
          if len(self.sampling_params.stop_strs) > 0:
@@ -284,13 +317,12 @@ class ScheduleBatch:
      reqs: List[Req]
      req_to_token_pool: ReqToTokenPool
      token_to_kv_pool: BaseTokenToKVPool
-     tree_cache: RadixCache
+     tree_cache: BasePrefixCache
 
      # Batched arguments to model runner
      input_ids: torch.Tensor = None
      req_pool_indices: torch.Tensor = None
      seq_lens: torch.Tensor = None
-     prefix_lens: torch.Tensor = None
      position_ids_offsets: torch.Tensor = None
      out_cache_loc: torch.Tensor = None
      extend_num_tokens: int = None
@@ -299,17 +331,11 @@ class ScheduleBatch:
      return_logprob: bool = False
      top_logprobs_nums: List[int] = None
 
-     # For multimodal
-     pixel_values: List[torch.Tensor] = None
-     image_sizes: List[List[int]] = None
-     image_offsets: List[int] = None
- 
      # Batched sampling params
      temperatures: torch.Tensor = None
      top_ps: torch.Tensor = None
      top_ks: torch.Tensor = None
-     frequency_penalties: torch.Tensor = None
-     presence_penalties: torch.Tensor = None
+     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
      logit_bias: torch.Tensor = None
 
      @classmethod
@@ -359,7 +385,7 @@ class ScheduleBatch:
 
          return out_cache_loc
 
-     def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+     def batch_sampling_params(self, vocab_size):
          device = "cuda"
          bs, reqs = self.batch_size(), self.reqs
          self.temperatures = torch.tensor(
@@ -373,85 +399,69 @@ class ScheduleBatch:
          self.top_ks = torch.tensor(
              [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
          )
-         self.frequency_penalties = torch.tensor(
-             [r.sampling_params.frequency_penalty for r in reqs],
-             dtype=torch.float,
-             device=device,
-         )
-         self.presence_penalties = torch.tensor(
-             [r.sampling_params.presence_penalty for r in reqs],
-             dtype=torch.float,
+ 
+         # Each penalizer will do nothing if it evaluates itself as not required by looking at
+         # the sampling_params of the requests (see {_is_required()} of each penalizer), so this
+         # should not add hefty computation overhead other than simple checks.
+         #
+         # While we could choose not to even create the class instances when they are not required,
+         # that would add additional complexity to the {ScheduleBatch} class, especially since we
+         # need to handle the {filter_batch()} and {merge()} cases as well.
+         self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+             vocab_size=vocab_size,
+             batch=self,
              device=device,
+             Penalizers={
+                 penaltylib.BatchedFrequencyPenalizer,
+                 penaltylib.BatchedMinNewTokensPenalizer,
+                 penaltylib.BatchedPresencePenalizer,
+                 penaltylib.BatchedRepetitionPenalizer,
+             },
          )
 
          # Handle logit bias but only allocate when needed
          self.logit_bias = None
-         for i in range(bs):
-             if reqs[i].sampling_params.dtype == "int":
-                 if self.logit_bias is None:
-                     self.logit_bias = torch.zeros(
-                         (bs, vocab_size), dtype=torch.float32, device=device
-                     )
-                 self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
-     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-         device = "cuda"
+     def prepare_for_extend(self, vocab_size: int):
          bs = self.batch_size()
          reqs = self.reqs
-         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-         prefix_indices = [r.prefix_indices for r in reqs]
- 
-         # Handle prefix
-         extend_lens = []
-         prefix_lens = []
+         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+         extend_num_tokens = sum(len(ids) for ids in input_ids)
          seq_lens = []
 
+         # Allocate memory
          req_pool_indices_cpu = self.alloc_req_slots(bs)
+         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+         pt = 0
          for i, req in enumerate(reqs):
              req.req_pool_idx = req_pool_indices_cpu[i]
-             extend_lens.append(len(input_ids[i]))
+             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+             ext_len = seq_len - pre_len
+             seq_lens.append(seq_len)
 
-             if len(prefix_indices[i]) == 0:
-                 prefix_lens.append(0)
-             else:
-                 prefix_lens.append(len(prefix_indices[i]))
+             if pre_len > 0:
                  self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                     : len(prefix_indices[i])
-                 ] = prefix_indices[i]
- 
-             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
- 
-         # Allocate memory
-         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+                     :pre_len
+                 ] = req.prefix_indices
 
-         pt = 0
-         for i, req in enumerate(reqs):
-             self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-             ] = out_cache_loc[pt : pt + extend_lens[i]]
-             pt += extend_lens[i]
+             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                 out_cache_loc[pt : pt + ext_len]
+             )
+             pt += ext_len
 
          # Set fields
          with torch.device("cuda"):
              self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
              self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
              self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-             self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
- 
-         self.pixel_values = [r.pixel_values for r in reqs]
-         self.image_sizes = [r.image_size for r in reqs]
-         self.image_offsets = [
-             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-         ]
-         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+             self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+ 
 
          self.extend_num_tokens = extend_num_tokens
          self.out_cache_loc = out_cache_loc
          self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-         self.batch_sampling_params(vocab_size, int_token_logit_bias)
+         self.batch_sampling_params(vocab_size)
 
      def check_decode_mem(self):
          bs = self.batch_size()
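The BatchedPenalizerOrchestrator above replaces the per-request frequency/presence penalty tensors and is later invoked from sample(), filter_batch(), and merge(). As a rough illustration of the "do nothing unless required" pattern the comment describes, here is a minimal, self-contained sketch; only _is_required() is named in the diff, so the class and its apply() signature are assumptions for illustration, not the penaltylib API:

    import torch

    class SketchFrequencyPenalizer:
        """Illustrative only; not the penaltylib implementation."""

        def __init__(self, frequency_penalties: torch.Tensor):
            # One penalty value per request in the batch, shape [batch_size].
            self.frequency_penalties = frequency_penalties.unsqueeze(1)  # [batch, 1]
            self.required = self._is_required()

        def _is_required(self) -> bool:
            # Only do tensor work if some request actually sets a non-zero penalty.
            return bool((self.frequency_penalties != 0.0).any())

        def apply(self, logits: torch.Tensor, output_counts: torch.Tensor) -> torch.Tensor:
            # logits: [batch, vocab]; output_counts: [batch, vocab] counts of generated tokens.
            if not self.required:
                return logits  # a cheap check, no extra computation
            return logits - self.frequency_penalties * output_counts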
@@ -522,7 +532,7 @@ class ScheduleBatch:
              residual_size = max(0, residual_size)
              self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
 
-             req.prefix_indices = None
+             req.prefix_indices = []
              req.last_node = None
              req.extend_input_len = 0
 
@@ -596,15 +606,7 @@ class ScheduleBatch:
                      req.vid += 1
 
                      # insert the old request into tree_cache
-                     self.tree_cache.cache_req(
-                         rid=req.rid,
-                         token_ids=cur_all_ids,
-                         last_uncached_pos=len(req.prefix_indices),
-                         req_pool_idx=req.req_pool_idx,
-                     )
- 
-                     # unlock the last node
-                     self.tree_cache.dec_lock_ref(req.last_node)
+                     self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                      # re-applying image padding
                      if req.pixel_values is not None:
@@ -621,19 +623,21 @@ class ScheduleBatch:
                      jump_forward_reqs.append(req)
                      filter_indices.remove(i)
 
-         if len(filter_indices) < len(self.reqs):
-             self.filter_batch(filter_indices)
+         self.filter_batch(filter_indices)
 
          return jump_forward_reqs
 
      def prepare_for_decode(self, input_ids=None):
          if input_ids is None:
              input_ids = [
-                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                 for r in self.reqs
              ]
+         else:
+             self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+ 
          self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
          self.seq_lens.add_(1)
-         self.prefix_lens = None
 
          # Alloc mem
          bs = self.batch_size()
@@ -644,23 +648,31 @@ class ScheduleBatch:
          ] = self.out_cache_loc
 
      def filter_batch(self, unfinished_indices: List[int]):
+         if unfinished_indices is None or len(unfinished_indices) == 0:
+             # Filter out all requests
+             self.reqs = []
+             return
+ 
+         if len(unfinished_indices) == len(self.reqs):
+             # No need to filter
+             return
+ 
          self.reqs = [self.reqs[i] for i in unfinished_indices]
          new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
          self.seq_lens = self.seq_lens[new_indices]
          self.input_ids = None
          self.req_pool_indices = self.req_pool_indices[new_indices]
-         self.prefix_lens = None
          self.position_ids_offsets = self.position_ids_offsets[new_indices]
          self.out_cache_loc = None
          self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
          self.return_logprob = any(req.return_logprob for req in self.reqs)
 
+         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+ 
          for item in [
              "temperatures",
              "top_ps",
              "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
              "logit_bias",
          ]:
              self_val = getattr(self, item, None)
@@ -668,13 +680,17 @@ class ScheduleBatch:
                  setattr(self, item, self_val[new_indices])
 
      def merge(self, other: "ScheduleBatch"):
+         # The penalizer orchestrator must be merged before Batch.reqs is merged, because
+         # orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
+         # needs to be called with the pre-merge Batch.reqs.
+         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+ 
          self.reqs.extend(other.reqs)
 
          self.req_pool_indices = torch.concat(
              [self.req_pool_indices, other.req_pool_indices]
          )
          self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-         self.prefix_lens = None
          self.position_ids_offsets = torch.concat(
              [self.position_ids_offsets, other.position_ids_offsets]
          )
@@ -686,8 +702,6 @@ class ScheduleBatch:
              "temperatures",
              "top_ps",
              "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
          ]:
              self_val = getattr(self, item, None)
              other_val = getattr(other, item, None)
@@ -711,6 +725,7 @@ class ScheduleBatch:
              self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
      def sample(self, logits: torch.Tensor):
+         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
          # Post process logits
          logits = logits.contiguous()
          logits.div_(self.temperatures)
@@ -728,7 +743,8 @@ class ScheduleBatch:
                      ] = 1
                      logits[i].masked_fill_(~allowed_mask, float("-inf"))
 
-         # TODO(lmzheng): apply penalty
+         logits = self.penalizer_orchestrator.apply(logits)
+ 
          probs = torch.softmax(logits, dim=-1)
 
          if not global_server_args_dict["disable_flashinfer_sampling"]:
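The orchestrator's apply() runs on the raw logits, before softmax and sampling, and the sampled ids are later fed back through cumulate_output_tokens(). As one example of what a penalizer in that set might compute, here is a hedged sketch of the standard repetition-penalty formula (the common formulation used across inference engines, not code copied from penaltylib):

    import torch

    def apply_repetition_penalty(
        logits: torch.Tensor,      # [batch, vocab]
        seen_mask: torch.Tensor,   # [batch, vocab] bool; True where the token already appeared
        penalties: torch.Tensor,   # [batch, 1], typically >= 1.0
    ) -> torch.Tensor:
        # Push already-seen tokens away from being re-sampled: divide positive logits,
        # multiply negative ones, and leave unseen tokens untouched.
        penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
        return torch.where(seen_mask, penalized, logits)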
@@ -761,6 +777,8 @@ class ScheduleBatch:
                          req.regex_fsm_state, batch_next_token_ids_cpu[i]
                      )
 
+         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+ 
          return batch_next_token_ids
 
 
@@ -780,7 +798,7 @@ def top_k_top_p_sampling_from_probs_torch(
          sampled_index = torch.multinomial(probs_sort, num_samples=1)
      except RuntimeError:
          batch_next_token_ids = torch.zeros(
-             (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
          )
          success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
          return batch_next_token_ids, success
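This last hunk only changes the dtype of the zero tensor returned on the fallback path. For context, a generic torch-only top-k/top-p sampler of the kind top_k_top_p_sampling_from_probs_torch implements looks roughly like the sketch below; this is a reconstruction under common conventions, not the package's exact code:

    import torch

    def top_k_top_p_sample_torch(
        probs: torch.Tensor,  # [batch, vocab], already softmaxed
        top_k: torch.Tensor,  # [batch] int
        top_p: torch.Tensor,  # [batch] float
    ) -> torch.Tensor:
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Drop tokens outside the nucleus (top-p) ...
        probs_sort[(probs_sum - probs_sort) > top_p.view(-1, 1)] = 0.0
        # ... and tokens ranked below top-k.
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
        probs_sort[ranks >= top_k.view(-1, 1)] = 0.0
        # Renormalize and sample, then map back to original vocabulary indices.
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, dim=-1, index=sampled).squeeze(-1)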