sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
                 )
             else:
                 raise ValueError(
@@ -98,18 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,23 +2,24 @@
 Common utilities for torchao.
 """
 
-from typing import Dict, Set
-
 import torch
 
 
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-
+def apply_torchao_config_to_model(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a modelwith torchao quantization specified by torchao_config
 
     Args:
-       `
-       `torchao_config
-        quantize the
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+        quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-
-
-
-
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +46,46 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(
-    elif "
-
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except:
+            print(
+                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except:
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
 
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
+    elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +96,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
 
-    return
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -17,9 +17,10 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
-    def
-
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output
 
-        #
-
-
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-
-
-    ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
@@ -124,9 +132,9 @@ class DetokenizerManager:
                 s.decode_ids = recv_obj.decode_ids[i]
 
             read_ids.append(
-                self.
+                self.trim_matched_stop(
                     s.decode_ids[s.surr_offset :],
-                    recv_obj.
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -149,7 +157,7 @@ class DetokenizerManager:
         for i in range(bs):
             s = self.decode_status[recv_obj.rids[i]]
             new_text = read_texts[i][len(surr_texts[i]) :]
-            if recv_obj.
+            if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
                 if len(new_text) > 0 and not new_text.endswith("�"):
                     s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                 new_text = find_printable_text(new_text)
 
             output_strs.append(
-                self.
+                self.trim_matched_stop(
                     s.decoded_text + new_text,
-                    recv_obj.
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -170,9 +178,20 @@ class DetokenizerManager:
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 rids=recv_obj.rids,
+                finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-
-
+                prompt_tokens=recv_obj.prompt_tokens,
+                completion_tokens=recv_obj.completion_tokens,
+                cached_tokens=recv_obj.cached_tokens,
+                input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
 
@@ -194,6 +213,7 @@ def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
+    setproctitle.setproctitle("sglang::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
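The detokenizer now receives per-request finish reasons as plain dicts and does the stop trimming itself through trim_matched_stop, instead of importing FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN from schedule_batch. The sketch below restates that trimming rule as a standalone function for illustration; the "<|eot|>" stop string and the token ids are made-up demo values.

from typing import Dict, List, Optional, Union

def trim_matched_stop(
    output: Union[str, List[int]],
    finished_reason: Optional[Dict],
    no_stop_trim: bool,
) -> Union[str, List[int]]:
    """Standalone rendition of the trimming rule added in this release."""
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    # A matched stop string: cut the text at its first occurrence.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    # A matched stop token id: drop the final token.
    if isinstance(matched, int) and isinstance(output, list):
        assert len(output) > 0
        return output[:-1]
    return output

assert trim_matched_stop("Hello<|eot|>", {"matched": "<|eot|>"}, False) == "Hello"
assert trim_matched_stop([11, 22, 99], {"matched": 99}, False) == [11, 22]
assert trim_matched_stop("Hello<|eot|>", {"matched": "<|eot|>"}, True) == "Hello<|eot|>"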
sglang/srt/managers/io_struct.py
CHANGED
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-
-
-
-
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    #
-
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
 
 
 @dataclass
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -58,6 +58,7 @@ global_server_args_dict = {
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
 }
 
 
@@ -128,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -180,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities",  # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
@@ -199,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
@@ -216,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds
 
+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -227,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -240,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""
 
         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0
 
         # For retraction
         self.is_retracted = False
 
         # Logprobs (arguments)
-        self.return_logprob =
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num =
+        self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.
-        self.
-        self.
-        self.
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -294,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)
 
-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
        return self.finished_reason is not None
 
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
@@ -453,8 +470,10 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.
-        self.
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
 
@@ -469,7 +488,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -1067,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)
 
-        self.return_logprob
-        self.has_stream
-        self.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1096,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1151,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]