sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/router/infer_batch.py
+++ /dev/null
@@ -1,503 +0,0 @@
-from dataclasses import dataclass
-from enum import Enum, auto
-from typing import List
-
-import numpy as np
-import torch
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-
-
-class ForwardMode(Enum):
-    PREFILL = auto()
-    EXTEND = auto()
-    DECODE = auto()
-
-
-class FinishReason(Enum):
-    LENGTH = auto()
-    EOS_TOKEN = auto()
-    STOP_STR = auto()
-
-
-class Req:
-    def __init__(self, rid, input_text, input_ids):
-        self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
-        self.output_ids = []
-
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
-
-        # For vision input
-        self.pixel_values = None
-        self.image_size = None
-        self.image_offset = 0
-        self.pad_value = None
-
-        self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
-        self.stream = False
-
-        self.tokenizer = None
-        self.finished = False
-        self.finish_reason = None
-        self.hit_stop_str = None
-
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
-        self.logprob = None
-        self.token_logprob = None
-        self.normalized_logprob = None
-
-        # For constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
-
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        old_output_str = self.tokenizer.decode(self.output_ids)
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
-            + jump_forward_str
-        )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        self.input_ids = new_input_ids
-        self.output_ids = []
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
-
-    def check_finished(self):
-        if self.finished:
-            return
-
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished = True
-            self.finish_reason = FinishReason.LENGTH
-            return
-
-        if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
-            and self.sampling_params.ignore_eos == False
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.EOS_TOKEN
-            return
-
-        if len(self.sampling_params.stop_strs) > 0:
-            tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-            )
-
-            for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str:
-                    self.finished = True
-                    self.finish_reason = FinishReason.STOP_STR
-                    self.hit_stop_str = stop_str
-                    return
-
-    def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
-
-
-@dataclass
-class Batch:
-    reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-    tree_cache: RadixCache
-
-    # batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-    return_logprob: bool = False
-
-    # for multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
-    # other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
-    logit_bias: torch.Tensor = None
-
-    @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-
-        return cls(
-            reqs=reqs,
-            req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
-            tree_cache=tree_cache,
-            return_logprob=return_logprob,
-        )
-
-    def is_empty(self):
-        return len(self.reqs) == 0
-
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
-
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Alloc mem
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                print("Prefill out of memory. This should nerver happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-        pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
-
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                logit_bias[i] = int_token_logit_bias
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
-        self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
-
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
-
-    def check_decode_mem(self):
-        bs = len(self.reqs)
-        if self.token_to_kv_pool.available_size() >= bs:
-            return True
-
-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
-            if self.token_to_kv_pool.available_size() >= bs:
-                return True
-
-        return False
-
-    def retract_decode(self):
-        sorted_indices = [i for i in range(len(self.reqs))]
-        sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
-            reverse=True,
-        )
-
-        retracted_reqs = []
-        seq_lens_np = self.seq_lens.cpu().numpy()
-        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
-            idx = sorted_indices.pop()
-            req = self.reqs[idx]
-            retracted_reqs.append(req)
-
-            self.tree_cache.dec_ref_counter(req.last_node)
-            req.prefix_indices = None
-            req.last_node = None
-            req.extend_input_len = 0
-            req.output_ids = []
-            req.regex_fsm_state = 0
-
-            # TODO: apply more fine-grained retraction
-
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
-
-        self.filter_batch(sorted_indices)
-
-        return retracted_reqs
-
-    def check_for_jump_forward(self):
-        jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
-
-        req_pool_indices_cpu = None
-
-        for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
-                        continue
-
-                    # insert the old request into tree_cache
-                    token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
-                    req_pool_idx = req_pool_indices_cpu[i]
-                    indices = self.req_to_token_pool.req_to_token[
-                        req_pool_idx, : len(token_ids_in_memory)
-                    ]
-                    prefix_len = self.tree_cache.insert(
-                        token_ids_in_memory, indices.clone()
-                    )
-                    self.token_to_kv_pool.free(indices[:prefix_len])
-                    self.req_to_token_pool.free(req_pool_idx)
-                    self.tree_cache.dec_ref_counter(req.last_node)
-
-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
-                    jump_forward_reqs.append(req)
-                    filter_indices.remove(i)
-
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
-
-        return jump_forward_reqs
-
-    def prepare_for_decode(self, input_ids=None):
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
-            ]
-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
-        self.seq_lens.add_(1)
-        self.prefix_lens = None
-
-        # Alloc mem
-        bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should nerver happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
-
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
-
-    def filter_batch(self, unfinished_indices: List[int]):
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-            "logit_bias",
-        ]:
-            setattr(self, item, getattr(self, item)[new_indices])
-
-    def merge(self, other):
-        self.reqs.extend(other.reqs)
-
-        self.req_pool_indices = torch.concat(
-            [self.req_pool_indices, other.req_pool_indices]
-        )
-        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-            "logit_bias",
-        ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
-            )
-
-    def sample(self, logits: torch.Tensor):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        # TODO(lmzheng): apply penalty
-        probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
-
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
--- a/sglang/srt/managers/router/manager.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import asyncio
-import logging
-
-import uvloop
-import zmq
-import zmq.asyncio
-from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
-from sglang.srt.managers.router.model_rpc import ModelRpcClient
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
-
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
-class RouterManager:
-    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
-        # Init communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        self.send_to_detokenizer = context.socket(zmq.PUSH)
-        self.send_to_detokenizer.connect(
-            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-        )
-
-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
-
-        # Init some configs
-        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
-
-    async def loop_for_forward(self):
-        while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
-
-            for obj in out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-
-            # async sleep for receiving the subsequent request and avoiding cache miss
-            if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished for obj in out_pyobjs])
-                if has_finished:
-                    await asyncio.sleep(self.extend_dependency_time)
-
-            await asyncio.sleep(0.0006)
-
-    async def loop_for_recv_requests(self):
-        while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
-
-
-def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
-):
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
-
-    try:
-        model_client = ModelRpcClient(server_args, port_args)
-        router = RouterManager(model_client, port_args)
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.create_task(router.loop_for_recv_requests())
-    loop.run_until_complete(router.loop_for_forward())