sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -18,18 +18,18 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Optional, Union
 
-import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -98,7 +98,7 @@ class Req:
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
-        self.
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
 
         # Memory info
         self.req_pool_idx = None
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset =
+        self.image_offset = None
         self.pad_value = None
 
         # Prefix info
@@ -142,6 +142,7 @@ class Req:
 
         # Logprobs
         self.return_logprob = False
+        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
@@ -162,6 +163,32 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None
 
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        if tree_cache is not None:
+            self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                rid=self.rid, key=self.adjust_max_prefix_ids()
+            )
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+
+    def adjust_max_prefix_ids(self):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        max_prefix_len = input_len
+
+        if self.sampling_params.max_new_tokens > 0:
+            # Need at least one token to compute logits
+            max_prefix_len = min(max_prefix_len, input_len - 1)
+
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+
+        if self.normalized_prompt_logprob is None:
+            # Need at least two tokens to compute normalized logprob
+            max_prefix_len = min(max_prefix_len, input_len - 2)
+
+        return self.fill_ids[:max_prefix_len]
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
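Note: `adjust_max_prefix_ids` above decides how many prompt tokens may be served from the prefix cache: at least one uncached token must remain so logits can be computed, `logprob_start_len` caps the prefix when logprobs are requested, and two uncached tokens are kept while the normalized prompt logprob is still missing. A standalone restatement of that clamping follows; the function name and flat parameters are illustrative, not sglang API:

    def clamp_prefix_len(
        input_len: int,
        max_new_tokens: int,
        return_logprob: bool,
        logprob_start_len: int,
        normalized_prompt_logprob_missing: bool,
    ) -> int:
        # Mirrors the min() chain in Req.adjust_max_prefix_ids above.
        max_prefix_len = input_len
        if max_new_tokens > 0:
            # Keep at least one uncached token so next-token logits can be computed.
            max_prefix_len = min(max_prefix_len, input_len - 1)
        if return_logprob:
            # Tokens from logprob_start_len onward must be recomputed to report logprobs.
            max_prefix_len = min(max_prefix_len, logprob_start_len)
        if normalized_prompt_logprob_missing:
            # Keep at least two uncached tokens for the normalized prompt logprob.
            max_prefix_len = min(max_prefix_len, input_len - 2)
        return max_prefix_len

    # Example: a 10-token fill with logprobs starting at token 4 reuses at most 4 cached tokens.
    assert clamp_prefix_len(10, 16, True, 4, False) == 4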
@@ -176,6 +203,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]
 
@@ -200,16 +229,20 @@ class Req:
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
             return
 
-
-
-
-
-
-
-
+        last_token_id = self.output_ids[-1]
+
+        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+
+        if self.tokenizer is not None:
+            matched_eos |= last_token_id == self.tokenizer.eos_token_id
+
+        if matched_eos and not self.sampling_params.ignore_eos:
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
         if len(self.sampling_params.stop_strs) > 0:
|
|
284
317
|
reqs: List[Req]
|
285
318
|
req_to_token_pool: ReqToTokenPool
|
286
319
|
token_to_kv_pool: BaseTokenToKVPool
|
287
|
-
tree_cache:
|
320
|
+
tree_cache: BasePrefixCache
|
288
321
|
|
289
322
|
# Batched arguments to model runner
|
290
323
|
input_ids: torch.Tensor = None
|
291
324
|
req_pool_indices: torch.Tensor = None
|
292
325
|
seq_lens: torch.Tensor = None
|
293
|
-
prefix_lens: torch.Tensor = None
|
294
326
|
position_ids_offsets: torch.Tensor = None
|
295
327
|
out_cache_loc: torch.Tensor = None
|
296
328
|
extend_num_tokens: int = None
|
@@ -299,17 +331,11 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
-
-    presence_penalties: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None
 
     @classmethod
@@ -359,7 +385,7 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size
+    def batch_sampling_params(self, vocab_size):
         device = "cuda"
         bs, reqs = self.batch_size(), self.reqs
         self.temperatures = torch.tensor(
@@ -373,85 +399,69 @@ class ScheduleBatch:
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
-
-
-
-
-
-
-
-
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=self,
             device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
         )
 
         # Handle logit bias but only allocate when needed
         self.logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if self.logit_bias is None:
-                    self.logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
-    def prepare_for_extend(self, vocab_size: int
-        device = "cuda"
+    def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
-        input_ids = [r.
-
-
-        # Handle prefix
-        extend_lens = []
-        prefix_lens = []
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
+        # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)
 
-            if
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
+            if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :
-                ] = prefix_indices
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+                    :pre_len
+                ] = req.prefix_indices
 
-
-
-
-
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len
 
         # Set fields
         with torch.device("cuda"):
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.
-
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-        self.batch_sampling_params(vocab_size
+        self.batch_sampling_params(vocab_size)
 
     def check_decode_mem(self):
         bs = self.batch_size()
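Note: this hunk replaces the per-batch frequency/presence penalty tensors with a single `penaltylib.BatchedPenalizerOrchestrator`, built in `batch_sampling_params` and then consulted in the hunks below from `prepare_for_decode` (`cumulate_input_tokens`), `sample` (`apply`, `cumulate_output_tokens`), `filter_batch`, and `merge`. For reference, the frequency and presence penalizers implement the usual OpenAI-style logit adjustment; the sketch below is a self-contained illustration of that adjustment, not the sglang implementation:

    import torch

    def apply_frequency_presence_penalty(
        logits: torch.Tensor,               # [batch_size, vocab_size]
        output_token_counts: torch.Tensor,  # [batch_size, vocab_size], counts of generated tokens
        frequency_penalty: torch.Tensor,    # [batch_size, 1]
        presence_penalty: torch.Tensor,     # [batch_size, 1]
    ) -> torch.Tensor:
        # logit -= count * frequency_penalty + (count > 0) * presence_penalty
        logits = logits - output_token_counts * frequency_penalty
        logits = logits - (output_token_counts > 0).float() * presence_penalty
        return logits

    # Toy usage: one request, vocab of 4, token 2 generated twice so far.
    logits = torch.zeros(1, 4)
    counts = torch.tensor([[0.0, 0.0, 2.0, 0.0]])
    penalized = apply_frequency_presence_penalty(
        logits,
        counts,
        frequency_penalty=torch.tensor([[0.5]]),
        presence_penalty=torch.tensor([[1.0]]),
    )
    # Token 2 is now disfavored: 0 - 2 * 0.5 - 1.0 = -2.0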
@@ -522,7 +532,7 @@ class ScheduleBatch:
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
 
-            req.prefix_indices =
+            req.prefix_indices = []
             req.last_node = None
             req.extend_input_len = 0
 
@@ -596,15 +606,7 @@ class ScheduleBatch:
                 req.vid += 1
 
                 # insert the old request into tree_cache
-                self.tree_cache.
-                    rid=req.rid,
-                    token_ids=cur_all_ids,
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-
-                # unlock the last node
-                self.tree_cache.dec_lock_ref(req.last_node)
+                self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                 # re-applying image padding
                 if req.pixel_values is not None:
@@ -621,19 +623,21 @@ class ScheduleBatch:
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
 
-
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)
 
         return jump_forward_reqs
 
     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
-                r.output_ids[-1] if r.output_ids else r.
+                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                for r in self.reqs
             ]
+        else:
+            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
-        self.prefix_lens = None
 
         # Alloc mem
         bs = self.batch_size()
@@ -644,23 +648,31 @@ class ScheduleBatch:
         ] = self.out_cache_loc
 
     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
         self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
@@ -668,13 +680,17 @@ class ScheduleBatch:
                 setattr(self, item, self_val[new_indices])
 
     def merge(self, other: "ScheduleBatch"):
+        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
+        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
+        # needs to be called with pre-merged Batch.reqs.
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
@@ -686,8 +702,6 @@ class ScheduleBatch:
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
@@ -711,6 +725,7 @@ class ScheduleBatch:
             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
@@ -728,7 +743,8 @@ class ScheduleBatch:
                     ] = 1
                     logits[i].masked_fill_(~allowed_mask, float("-inf"))
 
-
+        logits = self.penalizer_orchestrator.apply(logits)
+
         probs = torch.softmax(logits, dim=-1)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
@@ -761,6 +777,8 @@ class ScheduleBatch:
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
+        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+
         return batch_next_token_ids
 
 
@@ -780,7 +798,7 @@ def top_k_top_p_sampling_from_probs_torch(
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
     except RuntimeError:
         batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
         success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
         return batch_next_token_ids, success