sglang 0.4.0.post1__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +2 -0
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/ep_moe/layer.py +4 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +58 -10
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +35 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +38 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +169 -134
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +24 -10
- sglang/srt/model_executor/model_runner.py +22 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -8
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +1 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +7 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +11 -6
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +54 -52
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1,9 +1,11 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
@@ -24,11 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +98,8 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
@@ -306,7 +306,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +319,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config
 
     def create_weights(
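Fp8MoEMethod now declares no static base class; the hunk above attaches FusedMoEMethodBase at instantiation time inside __new__, so fp8.py no longer imports fused_moe_triton at module load. A minimal, self-contained sketch of the same deferred-base-injection pattern, with illustrative names (get_lazy_base, LazyBase, Worker) rather than sglang's:

# Sketch only: the deferred base is resolved when the first instance is built,
# not when the class is defined. All names here are illustrative.
def get_lazy_base():
    # Stands in for the lazy `from ... import FusedMoEMethodBase`, which would
    # otherwise create a circular import at module load time.
    class LazyBase:
        def base_hello(self):
            return "hello from the deferred base"

    return LazyBase


class Worker:
    def __new__(cls, *args, **kwargs):
        lazy_base = get_lazy_base()  # resolved only at instantiation time
        # Rebuild the class with the deferred base injected as a parent,
        # copying the original attributes (minus the per-class slot descriptors).
        new_cls = type(
            cls.__name__,
            (lazy_base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)  # not invoked automatically for a foreign class
        return obj

    def __init__(self, config):
        self.config = config


w = Worker({"bits": 8})
print(type(w).__mro__)  # Worker -> LazyBase -> object, assembled at call time
print(w.config, w.base_hello())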
@@ -331,6 +349,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -404,7 +423,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -428,6 +447,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -456,6 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
+
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
@@ -507,6 +540,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.w13_weight_scale = torch.nn.Parameter(
             max_w13_scales, requires_grad=False
         )
+
+        # If ROCm, apply weight padding (min. Mem channel contention) only if set
+        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+            layer.w13_weight = torch.nn.Parameter(
+                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
         return
 
     def apply(
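Both ROCm branches above pad the last dimension of w13_weight and w2_weight by padding_size when the MOE_PADDING environment variable is set to a non-zero value. A small standalone sketch of that env-gated padding on a dummy tensor (PAD and the shapes are illustrative, not values from the diff):

# Env-gated weight padding sketch; F.pad(..., (0, PAD)) pads only the last
# dimension on the right. PAD stands in for fused_moe.padding_size.
import os

import torch
import torch.nn.functional as F

PAD = 128  # illustrative padding amount

w13 = torch.randn(4, 256, 512)  # (num_experts, out_dim, in_dim), dummy sizes
if bool(int(os.getenv("MOE_PADDING", "0"))):
    w13 = torch.nn.Parameter(
        F.pad(w13.data, (0, PAD), "constant", 0),  # -> (4, 256, 512 + PAD)
        requires_grad=False,
    )
print(w13.shape)  # padded only when MOE_PADDING=1 is set in the environment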
@@ -521,9 +567,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
-
-
+        # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -535,6 +582,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             custom_routing_function=custom_routing_function,
         )
 
+        # Expert fusion with FP8 quantization
        return fused_experts(
             x,
             layer.w13_weight,
sglang/srt/layers/radix_attention.py
CHANGED
@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
 
-    def forward(
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
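The RadixAttention.forward signature above gains a save_kv_cache flag (default True). A toy, self-contained analogue of such a flag, where the layer only writes k/v into its cache when asked (ToyAttention and its dict cache are illustrative, not sglang's RadixAttention/ForwardBatch machinery):

# Toy analogue of the new save_kv_cache flag; not sglang code.
from typing import Dict, Tuple

import torch


class ToyAttention(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.kv_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def forward(self, q, k, v, step: int, save_kv_cache: bool = True):
        if save_kv_cache:
            self.kv_cache[step] = (k, v)  # skipped when the caller opts out
        scores = torch.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)
        return scores @ v


attn = ToyAttention()
q = k = v = torch.randn(1, 4, 8)
out = attn(q, k, v, step=0, save_kv_cache=False)  # nothing cached for this call
print(out.shape, len(attn.kv_cache))  # torch.Size([1, 4, 8]) 0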
sglang/srt/layers/sampler.py
CHANGED
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                 sampling_info.top_ks,
                 sampling_info.top_ps,
                 sampling_info.min_ps,
+                sampling_info.need_min_p_sampling,
             )
         else:
             raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
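The pytorch-backend branch of the new top_p_normalize_probs can be exercised on its own; the snippet below repeats its core steps on a tiny, made-up 2x4 probability batch:

# Standalone illustration of the top-p renormalization path added above
# (pytorch backend). The probability values are invented for the example.
import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05],
                      [0.4, 0.4, 0.1, 0.1]])
top_ps = torch.tensor([0.8, 0.8])

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
# Zero out tokens whose preceding cumulative mass already exceeds top_p.
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
print(renormed)  # each row sums to 1 over the kept top-p nucleus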
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -47,6 +47,41 @@ def apply_torchao_config_to_model(
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except:
+            print(
+                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except:
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
+
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
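The gemlite branch above encodes its settings in the torchao_config string as gemlite-<packing_bitwidth>-<bit_width>-<group_size>, with the packing bitwidth optional and group_size allowed to be the literal string "None". A standalone sketch of that parsing rule (parse_gemlite_config is an illustrative helper, not part of sglang):

# Illustrative parser for the gemlite config-string format used above.
from typing import Optional, Tuple


def parse_gemlite_config(torchao_config: str) -> Tuple[int, int, Optional[int]]:
    # "gemlite-<packing_bitwidth>-<bit_width>-<group_size>" or
    # "gemlite-<bit_width>-<group_size>"; group_size may be the string "None".
    parts = torchao_config.split("-")
    bit_width = int(parts[-2])
    group_size = None if parts[-1] == "None" else int(parts[-1])
    packing_bitwidth = int(parts[-3]) if parts[-3].isdigit() else 32
    return packing_bitwidth, bit_width, group_size


print(parse_gemlite_config("gemlite-4-64"))      # (32, 4, 64)
print(parse_gemlite_config("gemlite-8-4-None"))  # (8, 4, None)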
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -17,9 +17,10 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
-    def
-
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output
 
-        #
-
-
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-
-
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
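trim_matched_stop, added above, cuts the decoded output at the matched stop string or drops the trailing stop token unless no_stop_trim is set. A standalone copy of the same rule for quick experimentation (a plain function rather than the DetokenizerManager method):

# Standalone copy of the trimming rule for illustration; mirrors the behavior
# of DetokenizerManager.trim_matched_stop but takes plain arguments.
from typing import Dict, List, Union


def trim_matched_stop(
    output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    if isinstance(matched, int) and isinstance(output, list):
        return output[:-1]
    return output


print(trim_matched_stop("Hello world<|eot|>", {"matched": "<|eot|>"}, False))  # "Hello world"
print(trim_matched_stop([11, 22, 128009], {"matched": 128009}, False))         # [11, 22]
print(trim_matched_stop("Hello<|eot|>", {"matched": "<|eot|>"}, True))         # unchanged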
@@ -124,9 +132,9 @@ class DetokenizerManager:
             s.decode_ids = recv_obj.decode_ids[i]
 
             read_ids.append(
-                self.
+                self.trim_matched_stop(
                     s.decode_ids[s.surr_offset :],
-                    recv_obj.
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -149,7 +157,7 @@ class DetokenizerManager:
         for i in range(bs):
             s = self.decode_status[recv_obj.rids[i]]
             new_text = read_texts[i][len(surr_texts[i]) :]
-            if recv_obj.
+            if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
                 if len(new_text) > 0 and not new_text.endswith("�"):
                     s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                 new_text = find_printable_text(new_text)
 
             output_strs.append(
-                self.
+                self.trim_matched_stop(
                     s.decoded_text + new_text,
-                    recv_obj.
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -170,9 +178,20 @@ class DetokenizerManager:
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 rids=recv_obj.rids,
+                finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-
-
+                prompt_tokens=recv_obj.prompt_tokens,
+                completion_tokens=recv_obj.completion_tokens,
+                cached_tokens=recv_obj.cached_tokens,
+                input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
 
@@ -194,6 +213,7 @@ def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
+    setproctitle.setproctitle("sglang::detokenizer")
    configure_logger(server_args)
    parent_process = psutil.Process().parent()
 
sglang/srt/managers/io_struct.py
CHANGED
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-
-
-
-
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    #
-
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
 
 
 @dataclass
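The batch output dataclasses above use a struct-of-lists layout: index i of every field belongs to request rids[i]. A toy illustration of reading such an object back per request (ToyBatchStrOut and the sample values are invented, not sglang's class):

# Toy illustration of the struct-of-lists layout used by BatchStrOut above.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyBatchStrOut:
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]


batch = ToyBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[None, {"type": "stop", "matched": 128009}],
    output_strs=["Hello", "Hi there"],
    prompt_tokens=[12, 9],
    completion_tokens=[1, 3],
)

for i, rid in enumerate(batch.rids):
    # Element i of every list describes the same request.
    print(rid, batch.output_strs[i], batch.finished_reasons[i],
          batch.prompt_tokens[i], batch.completion_tokens[i])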
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -129,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -181,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
@@ -200,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
@@ -217,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds
 
+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -228,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -241,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""
 
         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0
 
         # For retraction
         self.is_retracted = False
 
         # Logprobs (arguments)
-        self.return_logprob =
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num =
+        self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.
-        self.
-        self.
-        self.
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
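Req.__init__ above now takes return_logprob, top_logprobs_num, and stream directly and only allocates the output-logprob accumulators when logprobs were requested. A toy sketch of that conditional-buffer pattern (ToyReq and its fields are stand-ins, not the sglang class):

# Illustrative sketch of the conditional logprob-buffer allocation added to
# Req.__init__ above; ToyReq is not sglang code.
from typing import List, Optional


class ToyReq:
    def __init__(self, return_logprob: bool = False, top_logprobs_num: int = 0,
                 stream: bool = False):
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num
        self.stream = stream
        # Accumulators become lists only when logprobs will actually be
        # appended; otherwise they stay None and nothing is collected.
        if return_logprob:
            self.output_token_logprobs_val: Optional[List[float]] = []
            self.output_token_logprobs_idx: Optional[List[int]] = []
        else:
            self.output_token_logprobs_val = None
            self.output_token_logprobs_idx = None


req = ToyReq(return_logprob=True, top_logprobs_num=5, stream=True)
print(req.output_token_logprobs_val)  # [] -> ready to accumulate per token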
@@ -295,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)
 
-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
         return self.finished_reason is not None
 
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
@@ -454,8 +470,10 @@ class Req:
                     k = k + 1
                 else:
                     break
-            self.
-            self.
+            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
             self.logprob_start_len = prompt_tokens + k
             self.last_update_decode_tokens = len(self.output_ids) - k
 
@@ -470,7 +488,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -1068,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)
 
-        self.return_logprob
-        self.has_stream
-        self.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
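merge_batch above combines the per-batch flags with in-place OR, so the merged batch streams, returns logprobs, or uses grammar if either side did. A tiny self-contained illustration of those semantics (ToyBatch is a stand-in for ScheduleBatch):

# Illustration of the |= merge semantics used above; not sglang code.
from dataclasses import dataclass


@dataclass
class ToyBatch:
    return_logprob: bool = False
    has_stream: bool = False
    has_grammar: bool = False

    def merge_batch(self, other: "ToyBatch") -> None:
        # A flag stays True if either batch had it set.
        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar


a = ToyBatch(return_logprob=False, has_stream=True)
b = ToyBatch(return_logprob=True, has_stream=False, has_grammar=True)
a.merge_batch(b)
print(a)  # ToyBatch(return_logprob=True, has_stream=True, has_grammar=True)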
@@ -1097,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1152,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]