sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,10 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-import torch
-
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 
 
 @dataclass
@@ -43,9 +41,9 @@ class GenerateReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    #
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    #
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
@@ -77,7 +75,7 @@ class GenerateReqInput:
             if self.return_logprob is None:
                 self.return_logprob = False
             if self.logprob_start_len is None:
-                self.logprob_start_len =
+                self.logprob_start_len = -1
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
@@ -143,7 +141,7 @@ class GenerateReqInput:
                 self.return_logprob = [self.return_logprob] * num
 
             if self.logprob_start_len is None:
-                self.logprob_start_len = [
+                self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
 
@@ -155,16 +153,27 @@
 
 @dataclass
 class TokenizedGenerateReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # The pixel values for input images
     pixel_values: List[float]
+    # The hash of input images
     image_hash: int
+    # The image size
     image_size: List[int]
+    # The sampling parameters
     sampling_params: SamplingParams
+    # Whether to return the logprobs
     return_logprob: bool
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: int
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: int
+    # Whether to stream output
     stream: bool
 
 
@@ -215,15 +224,21 @@ class EmbeddingReqInput:
 
 @dataclass
 class TokenizedEmbeddingReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
 
 @dataclass
 class BatchTokenIDOut:
+    # The request id
     rids: List[str]
+    # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
@@ -236,17 +251,25 @@
 
 @dataclass
 class BatchStrOut:
+    # The request id
     rids: List[str]
+    # The output decoded strings
     output_strs: List[str]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
 @dataclass
 class BatchEmbeddingOut:
+    # The request id
     rids: List[str]
+    # The output embedding
     embeddings: List[List[float]]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
@@ -256,10 +279,20 @@ class FlushCacheReq:
 
 
 @dataclass
-class
-
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
 
 
 @dataclass
-class
-
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str
sglang/srt/managers/policy_scheduler.py
CHANGED
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens
 
         self.can_run_list = []
         self.new_inflight_req = None
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,20 +18,22 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -37,7 +41,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
@@ -235,10 +239,12 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-
-
-
-
+
+        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+
+        if self.tokenizer is not None:
+            matched_eos |= last_token_id == self.tokenizer.eos_token_id
+
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
@@ -266,7 +272,7 @@ class Req:
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -325,17 +331,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunekd prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -383,51 +385,7 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if self.logit_bias is None:
-                    self.logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+    def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -465,8 +423,32 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
 
-        self.
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
@@ -639,7 +621,7 @@ class ScheduleBatch:
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
@@ -652,6 +634,8 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -672,23 +656,13 @@ class ScheduleBatch:
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None: # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
@@ -703,111 +677,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-
-
-
-
-
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
             batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
+                sample_output.success, batch_next_token_ids, argmax_ids
             )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-
-        batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-        for i, req in enumerate(self.reqs):
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                )
-
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
-
-        return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+        return sample_output.batch_next_token_ids