sglang-0.5.3-py3-none-any.whl → sglang-0.5.3.post1-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -39,7 +39,6 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
 
@@ -49,29 +48,29 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
+            copy_done,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
+            result.copy_done,
         )
 
-        if
-                            logits_output.input_token_logprobs.tolist()
-                        )
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        # Move next_token_ids and logprobs to cpu
+        next_token_ids = next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
 
         hidden_state_offset = 0
 
@@ -105,7 +104,10 @@ class SchedulerOutputProcessorMixin:
                 assert extend_input_len_per_req is not None
                 extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                 extend_input_len = extend_input_len_per_req[i]
+
+                num_input_logprobs = self._calculate_num_input_logprobs(
+                    req, extend_input_len, extend_logprob_start_len
+                )
 
                 if req.return_logprob:
                     self.add_logprob_return_values(
@@ -160,8 +162,8 @@ class SchedulerOutputProcessorMixin:
                     extend_input_len = extend_input_len_per_req[i]
                     if extend_logprob_start_len < extend_input_len:
                         # Update input logprobs.
-                        num_input_logprobs = (
-                            extend_input_len
+                        num_input_logprobs = self._calculate_num_input_logprobs(
+                            req, extend_input_len, extend_logprob_start_len
                         )
                         if req.return_logprob:
                             self.add_input_logprob_return_values(
@@ -174,8 +176,6 @@ class SchedulerOutputProcessorMixin:
                             )
                             logprob_pt += num_input_logprobs
 
-            self.set_next_batch_sampling_info_done(batch)
-
         else:  # embedding or reward model
             embeddings = result.embeddings.tolist()
 
@@ -204,22 +204,19 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph = (
+        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
            result.logits_output,
            result.next_token_ids,
            result.can_run_cuda_graph,
+           result.copy_done,
        )
        self.num_generated_tokens += len(batch.reqs)
 
-        if
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        if batch.spec_algorithm.is_none():
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
@@ -299,7 +296,6 @@ class SchedulerOutputProcessorMixin:
                    self.abort_request(AbortReq(rid=req.rid))
                req.grammar.finished = req.finished()
 
-        self.set_next_batch_sampling_info_done(batch)
        self.stream_output(batch.reqs, batch.return_logprob)
        self.token_to_kv_pool_allocator.free_group_end()
 
@@ -310,6 +306,153 @@ class SchedulerOutputProcessorMixin:
        ):
            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
 
+    def _process_input_token_logprobs(
+        self, req: Req, input_token_logprobs: List
+    ) -> None:
+        """Process input token logprobs values and indices."""
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Process logprob values - handle multi-item scoring vs regular requests
+        if is_multi_item_scoring:
+            # Multi-item scoring: use all logprobs as-is
+            req.input_token_logprobs_val = input_token_logprobs
+        else:
+            # Regular request: add None at start, remove last (sampling token)
+            req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
+
+        # Process logprob indices based on scoring type
+        if is_multi_item_scoring:
+            # Multi-item scoring: only include delimiter token positions
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            input_token_logprobs_idx = [
+                token_id
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            ]
+        else:
+            # Regular request: include all tokens from logprob_start_len onwards
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+
+        # Clip padded hash values from image tokens to prevent detokenization errors
+        req.input_token_logprobs_idx = [
+            x if x < self.model_config.vocab_size - 1 else 0
+            for x in input_token_logprobs_idx
+        ]
+
+    def _process_input_top_logprobs(self, req: Req) -> None:
+        """Process input top logprobs."""
+        if req.top_logprobs_num <= 0:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Extend arrays with temp values
+        for val, idx in zip(
+            req.temp_input_top_logprobs_val,
+            req.temp_input_top_logprobs_idx,
+            strict=True,
+        ):
+            req.input_top_logprobs_val.extend(val)
+            req.input_top_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_top_logprobs_val.pop()
+            req.input_top_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_top_logprobs_idx = None
+        req.temp_input_top_logprobs_val = None
+
+    def _process_input_token_ids_logprobs(self, req: Req) -> None:
+        """Process input token IDs logprobs."""
+        if req.token_ids_logprob is None:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Process temp values - convert tensors to lists and extend arrays
+        for val, idx in zip(
+            req.temp_input_token_ids_logprobs_val,
+            req.temp_input_token_ids_logprobs_idx,
+            strict=True,
+        ):
+            val_list = val.tolist() if isinstance(val, torch.Tensor) else val
+            req.input_token_ids_logprobs_val.extend(
+                val_list if isinstance(val_list, list) else [val_list]
+            )
+            req.input_token_ids_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_token_ids_logprobs_val.pop()
+            req.input_token_ids_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_token_ids_logprobs_idx = None
+        req.temp_input_token_ids_logprobs_val = None
+
+    def _calculate_relevant_tokens_len(self, req: Req) -> int:
+        """Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions from logprob_start_len onwards have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens from logprob_start_len onwards
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens from logprob_start_len onwards
+            return len(req.origin_input_ids) - req.logprob_start_len
+
+    def _calculate_num_input_logprobs(
+        self, req: Req, extend_input_len: int, extend_logprob_start_len: int
+    ) -> int:
+        """Calculate the number of input logprobs based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions in the range have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens in the relevant portion
+            relevant_tokens = req.origin_input_ids[
+                extend_logprob_start_len:extend_input_len
+            ]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens in the range
+            return extend_input_len - extend_logprob_start_len
+
+    def _is_multi_item_scoring(self, req: Req) -> bool:
+        """Check if request uses multi-item scoring.
+
+        Multi-item scoring applies to prefill-only requests when a delimiter
+        token is configured. In this mode, only positions containing the
+        delimiter token receive logprobs.
+        """
+        return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
+
     def add_input_logprob_return_values(
         self: Scheduler,
         i: int,
@@ -378,63 +521,14 @@ class SchedulerOutputProcessorMixin:
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None
 
-            #
-            req
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
+            # Process all input logprob types using helper functions
+            self._process_input_token_logprobs(req, input_token_logprobs)
+            self._process_input_top_logprobs(req)
 
-
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val,
-                    req.temp_input_top_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
+            self._process_input_token_ids_logprobs(req)
 
            if req.return_logprob:
-                relevant_tokens_len =
+                relevant_tokens_len = self._calculate_relevant_tokens_len(req)
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0: