sglang-0.1.17-py3-none-any.whl → sglang-0.1.18-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/infer_batch.py (removed in 0.1.18)
@@ -1,596 +0,0 @@
-from dataclasses import dataclass
-from enum import IntEnum, auto
-from typing import List
-
-import numpy as np
-import torch
-
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-
-
-class ForwardMode(IntEnum):
-    PREFILL = auto()
-    EXTEND = auto()
-    DECODE = auto()
-
-
-class FinishReason(IntEnum):
-    EOS_TOKEN = auto()
-    LENGTH = auto()
-    STOP_STR = auto()
-    ABORT = auto()
-
-    @staticmethod
-    def to_str(reason):
-        if reason == FinishReason.EOS_TOKEN:
-            return None
-        elif reason == FinishReason.LENGTH:
-            return "length"
-        elif reason == FinishReason.STOP_STR:
-            return "stop"
-        elif reason == FinishReason.ABORT:
-            return "abort"
-        else:
-            return None
-
-
-class Req:
-    def __init__(self, rid, origin_input_text, origin_input_ids):
-        self.rid = rid
-        self.origin_input_text = origin_input_text
-        self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
-
-        # For vision input
-        self.pixel_values = None
-        self.image_size = None
-        self.image_offset = 0
-        self.pad_value = None
-
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
-        # Check finish
-        self.tokenizer = None
-        self.finished = False
-        self.finish_reason = None
-        self.hit_stop_str = None
-
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
-        # Logprobs
-        self.return_logprob = False
-        self.logprob_start_len = 0
-        self.top_logprobs_num = 0
-        self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-
-        # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
-
-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
-
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
-    def check_finished(self):
-        if self.finished:
-            return
-
-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.LENGTH
-            return
-
-        if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
-            and self.sampling_params.ignore_eos == False
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.EOS_TOKEN
-            return
-
-        if len(self.sampling_params.stop_strs) > 0:
-            tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-            )
-
-            for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
-                    self.finished = True
-                    self.finish_reason = FinishReason.STOP_STR
-                    self.hit_stop_str = stop_str
-                    return
-
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
-        all_ids = self.tokenizer.encode(all_text)
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
-
-        self.regex_fsm_state = next_state
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.decode_token_logprobs = self.decode_token_logprobs[:k]
-            self.decode_top_logprobs = self.decode_top_logprobs[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
-
-    def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
-
-
-@dataclass
-class Batch:
-    reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-    tree_cache: RadixCache
-
-    # batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    # for processing logprobs
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
-    # other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
-    logit_bias: torch.Tensor = None
-
-    @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-
-        return cls(
-            reqs=reqs,
-            req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
-            tree_cache=tree_cache,
-            return_logprob=return_logprob,
-        )
-
-    def is_empty(self):
-        return len(self.reqs) == 0
-
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
-
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Alloc mem
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                print("Prefill out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-        pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
-
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i] = int_token_logit_bias
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
-        self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
-
-    def check_decode_mem(self):
-        bs = len(self.reqs)
-        if self.token_to_kv_pool.available_size() >= bs:
-            return True
-
-        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
-
-        if self.token_to_kv_pool.available_size() >= bs:
-            return True
-
-        return False
-
-    def retract_decode(self):
-        sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
-        sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
-            reverse=True,
-        )
-
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
-            idx = sorted_indices.pop()
-            req = self.reqs[idx]
-            retracted_reqs.append(req)
-
-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = len(req.prefix_indices)
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
-
-            # release the last node
-            self.tree_cache.dec_lock_ref(req.last_node)
-
-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
-            req.prefix_indices = None
-            req.last_node = None
-            req.extend_input_len = 0
-            req.output_ids = []
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
-
-        self.filter_batch(sorted_indices)
-
-        return retracted_reqs
-
-    def check_for_jump_forward(self, model_runner):
-        jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
-
-        req_pool_indices_cpu = None
-
-        for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
-                        continue
-
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req_pool_indices_cpu[i],
-                    )
-
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
-
-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
-                    # re-applying image padding
-                    if req.pixel_values is not None:
-                        (
-                            req.origin_input_ids,
-                            req.image_offset,
-                        ) = model_runner.model.pad_input_ids(
-                            req.origin_input_ids_unpadded,
-                            req.pad_value,
-                            req.pixel_values.shape,
-                            req.image_size,
-                        )
-
-                    jump_forward_reqs.append(req)
-                    filter_indices.remove(i)
-
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
-
-        return jump_forward_reqs
-
-    def prepare_for_decode(self, input_ids=None):
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
-            ]
-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
-        self.seq_lens.add_(1)
-        self.prefix_lens = None
-
-        # Alloc mem
-        bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
-
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
-
-    def filter_batch(self, unfinished_indices: List[int]):
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
-                setattr(self, item, self_val[new_indices])
-
-    def merge(self, other: "Batch"):
-        self.reqs.extend(other.reqs)
-
-        self.req_pool_indices = torch.concat(
-            [self.req_pool_indices, other.req_pool_indices]
-        )
-        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        # TODO(lmzheng): apply penalty
-        probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
-
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
sglang/srt/managers/router/manager.py (removed in 0.1.18)
@@ -1,82 +0,0 @@
-import asyncio
-import logging
-
-import uvloop
-import zmq
-import zmq.asyncio
-
-from sglang.global_config import global_config
-from sglang.srt.managers.router.model_rpc import ModelRpcClient
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
-
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
-class RouterManager:
-    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
-        # Init communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        self.send_to_detokenizer = context.socket(zmq.PUSH)
-        self.send_to_detokenizer.connect(
-            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-        )
-
-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
-
-        # Init some configs
-        self.request_dependency_time = global_config.request_dependency_time
-
-    async def loop_for_forward(self):
-        while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
-
-            for obj in out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-
-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished for obj in out_pyobjs])
-                if has_finished:
-                    if self.request_dependency_time > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_time)
-
-            if not slept:
-                await asyncio.sleep(0.0006)
-
-    async def loop_for_recv_requests(self):
-        while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
-
-
-def start_router_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
-):
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
-
-    try:
-        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
-        router = RouterManager(model_client, port_args)
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.create_task(router.loop_for_recv_requests())
-    loop.run_until_complete(router.loop_for_forward())