sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{router → controller}/infer_batch.py

@@ -1,3 +1,6 @@
+"""Meta data for requests and batches"""
+
+import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
@@ -5,9 +8,13 @@ from typing import List
 import numpy as np
 import torch
 
-from sglang.srt.
+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
 
 class ForwardMode(IntEnum):
     PREFILL = auto()
@@ -15,33 +22,62 @@ class ForwardMode(IntEnum):
     DECODE = auto()
 
 
-class
-
-
-
+class BaseFinishReason:
+    def __init__(self, is_error: bool = False):
+        self.is_error = is_error
+
+    def __str__(self):
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+class FINISH_MATCHED_TOKEN(BaseFinishReason):
+    def __init__(self, matched: int | List[int]):
+        super().__init__()
+        self.matched = matched
+
+    def __str__(self) -> str:
+        return f"FINISH_MATCHED_TOKEN: {self.matched}"
+
+
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
+        super().__init__()
+        self.length = length
+
+    def __str__(self) -> str:
+        return f"FINISH_LENGTH: {self.length}"
+
+
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def __str__(self) -> str:
+        return f"FINISH_MATCHED_STR: {self.matched}"
 
-
-
-
-
-
-
-
-            return "stop"
-        else:
-            return None
+
+class FINISH_ABORT(BaseFinishReason):
+    def __init__(self):
+        super().__init__(is_error=True)
+
+    def __str__(self) -> str:
+        return "FINISH_ABORT"
 
 
 class Req:
-    def __init__(self, rid,
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.
-        self.
-        self.
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+        self.origin_input_ids = origin_input_ids
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        #
-
-        self.
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
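The boolean finished flag and enum-style finish reason from 0.1.16 are replaced above by small finish-reason objects stored on `finished_reason`. As a hedged illustration, the sketch below shows how downstream code might translate these objects into an OpenAI-style `finish_reason` string; the helper name `to_finish_reason_str` is hypothetical and not part of sglang.

# Hypothetical helper, not part of sglang: map the finish-reason classes
# defined in infer_batch.py above to an OpenAI-style finish_reason string.
from sglang.srt.managers.controller.infer_batch import (
    FINISH_ABORT,
    FINISH_LENGTH,
    FINISH_MATCHED_STR,
    FINISH_MATCHED_TOKEN,
)

def to_finish_reason_str(reason):
    if reason is None:
        return None  # the request is still running
    if isinstance(reason, (FINISH_MATCHED_TOKEN, FINISH_MATCHED_STR)):
        return "stop"  # hit the EOS token or a stop string
    if isinstance(reason, FINISH_LENGTH):
        return "length"  # hit max_new_tokens
    if isinstance(reason, FINISH_ABORT):
        return "abort"  # aborted; reason.is_error is True
    raise ValueError(f"unknown finish reason: {reason}")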
@@ -57,12 +93,12 @@ class Req:
         self.sampling_params = None
         self.stream = False
 
-        # Check finish
         self.tokenizer = None
-        self.finished = False
-        self.finish_reason = None
-        self.hit_stop_str = None
 
+        # Check finish
+        self.finished_reason = None
+
+        # Prefix info
         self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
@@ -73,80 +109,81 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
         self.prefill_top_logprobs = None
-        self.
+        self.decode_token_logprobs = []
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
-
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None
+
+    # whether request reached finished condition
+    def finished(self) -> bool:
+        return self.finished_reason is not None
+
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
 
-
-
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
 
-
-        old_output_str = self.tokenizer.decode(self.output_ids)
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        if self.input_text is None:
-            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
-            self.input_text = self.tokenizer.decode(self.input_ids)
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
-            + jump_forward_str
-        )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
+        return surr_ids, read_ids, len(all_ids)
 
-
-
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
 
-
-
-
-            self.sampling_params.
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
         )
-
-
-        self.
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
 
-
-
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""
+
+    def max_new_tokens(self):
+        return self.sampling_params.max_new_tokens
 
     def check_finished(self):
-        if self.finished:
+        if self.finished():
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.
-            self.finish_reason = FinishReason.LENGTH
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return
 
         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
-            and self.sampling_params.ignore_eos
+            and not self.sampling_params.ignore_eos
         ):
-            self.
-
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return
 
         if len(self.sampling_params.stop_strs) > 0:
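The new `init_detokenize_incrementally` / `detokenize_incrementally` pair follows the windowed scheme referenced from vLLM's detokenizer: keep a small window of surrounding tokens between `surr_offset` and `read_offset`, decode it together with the newly generated tokens, and only emit text once it no longer ends in an incomplete UTF-8 character. A standalone sketch of the same idea follows; it is not the sglang code and assumes a Hugging Face-style tokenizer.

# Standalone sketch of the surr_offset/read_offset windowing idea used above.
# `tokenizer` is assumed to be a Hugging Face tokenizer; the real code also
# forwards skip_special_tokens / spaces_between_special_tokens from sampling params.
def stream_decode(tokenizer, all_ids, surr_offset, read_offset):
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    new_text = tokenizer.decode(all_ids[surr_offset:])
    if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
        # A complete character was produced: emit it and slide the window forward.
        return new_text[len(surr_text):], read_offset, len(all_ids)
    # Still inside a multi-byte character: emit nothing, keep the window.
    return "", surr_offset, read_offset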
@@ -155,14 +192,62 @@ class Req:
             )
 
             for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str:
-                    self.
-                    self.finish_reason = FinishReason.STOP_STR
-                    self.hit_stop_str = stop_str
+                if stop_str in tail_str or stop_str in self.decoded_text:
+                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return
 
+    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrouding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.output_ids) - k
+
+        return True
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
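`jump_forward_and_retokenize` rebuilds the request by re-encoding prompt text plus decoded text plus the jumped string, which can merge ("fuse") the last prompt token with the first output characters and shift the prompt/output boundary. A compressed sketch of that boundary check follows; the names are illustrative and `tokenizer` is any Hugging Face-style tokenizer.

# Sketch of the token-fusion guard used above (illustrative names only).
def retokenize_after_jump(tokenizer, prompt_ids, prompt_text, decoded_text, jump_str):
    all_ids = tokenizer.encode(prompt_text + decoded_text + jump_str)
    prompt_tokens = len(prompt_ids)
    # If re-encoding changed the token at the prompt boundary, the prompt and
    # output ids can no longer be split cleanly, so the jump is rejected.
    if all_ids[prompt_tokens - 1] != prompt_ids[-1]:
        return None
    return all_ids[prompt_tokens:]  # the new output ids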
@@ -218,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
@@ -333,8 +422,12 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )
 
@@ -353,18 +446,22 @@ class Batch:
             ][last_uncached_pos : seq_lens_cpu[idx]]
             self.token_to_kv_pool.dec_refs(token_indices)
 
+            # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-
-
+
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
@@ -372,18 +469,54 @@ class Batch:
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-
-
-
-
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue
 
-
-
+                    (
+                        jump_forward_str,
+                        next_state,
+                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue
 
                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )
@@ -391,8 +524,17 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)
 
-                    #
-                    req.
+                    # re-applying image padding
+                    if req.pixel_values is not None:
+                        (
+                            req.origin_input_ids,
+                            req.image_offset,
+                        ) = model_runner.model.pad_input_ids(
+                            req.origin_input_ids_unpadded,
+                            req.pad_value,
+                            req.pixel_values.shape,
+                            req.image_size,
+                        )
 
                     jump_forward_reqs.append(req)
                     filter_indices.remove(i)
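The jump-forward logic above consumes deterministic transitions from the regex FSM (via `jump_forward_byte` / `jump_forward_symbol`) so that text forced by the grammar is appended without extra model forward passes, with special handling for UTF-8 continuation bytes. A toy illustration of the core idea follows; it uses a plain dict rather than sglang's `JumpForwardMap`.

# Toy illustration of jump-forward decoding, not the sglang implementation.
# Whenever the current FSM state has exactly one outgoing edge, the next
# character is forced by the grammar and can be emitted without the model.
def jump_forward(edges, state):
    """edges: dict mapping state -> list of (char, next_state) transitions."""
    forced = []
    while len(edges.get(state, [])) == 1:
        ch, state = edges[state][0]
        forced.append(ch)
    return "".join(forced), state

# A regex such as '"name"' forces its literal prefix character by character.
edges = {0: [('"', 1)], 1: [("n", 2)], 2: [("a", 3)], 3: [("m", 4)], 4: [("e", 5)], 5: [('"', 6)]}
print(jump_forward(edges, 0))  # ('"name"', 6)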
@@ -515,7 +657,7 @@ class Batch:
             if req.regex_fsm is not None:
                 allowed_mask.zero_()
                 allowed_mask[
-                    req.regex_fsm.
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))
 
@@ -534,7 +676,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
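The two hunks above switch the constrained-decoding calls to the `RegexGuide`-style interface imported at the top of the file: `get_next_instruction(state).tokens` supplies the allowed-token mask and `get_next_state(state, token_id)` advances the FSM. A minimal sketch of applying that mask to a single logits vector follows; it assumes `guide` exposes exactly those two methods and is not sglang's batched implementation.

import torch

# Minimal sketch mirroring the masking in the diff above; `guide` is assumed
# to expose get_next_instruction / get_next_state as shown in the hunks.
def mask_and_step(logits: torch.Tensor, guide, state: int):
    allowed = torch.zeros_like(logits, dtype=torch.bool)
    allowed[guide.get_next_instruction(state).tokens] = True
    masked = logits.masked_fill(~allowed, float("-inf"))
    next_token = int(torch.argmax(masked))
    return next_token, guide.get_next_state(state, next_token)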
sglang/srt/managers/controller/manager_multi.py (new file)

@@ -0,0 +1,191 @@
+"""
+A controller that manages multiple data parallel workers.
+Each data parallel worker can manage multiple tensor parallel workers.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum, auto
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+
+
+class LoadBalanceMethod(Enum):
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class Controller:
+    def __init__(
+        self,
+        load_balance_method: str,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        self.server_args = server_args
+        self.port_args = port_args
+
+        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+            self.round_robin_counter = 0
+
+        self.dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+        # Init communication
+        context = zmq.asyncio.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        # Init status
+        self.recv_reqs = []
+
+        # Start data parallel workers
+        self.workers: Dict[int, DataParallelWorkerThread] = {}
+        tp_size = server_args.tp_size
+
+        def start_dp_worker(i):
+            try:
+                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                worker_thread = start_data_parallel_worker(
+                    server_args, port_args, model_overide_args, gpu_ids, i
+                )
+                self.workers[i] = worker_thread
+            except Exception:
+                logger.error(
+                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                )
+
+        for i in range(server_args.dp_size):
+            start_dp_worker(i)
+
+        # Parallel launch is slower, probably due to the disk bandwidth limitations.
+        # with ThreadPoolExecutor(server_args.dp_size) as executor:
+        #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+    def have_any_live_worker(self):
+        return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+    def put_req_to_worker(self, worker_id, req):
+        self.workers[worker_id].request_queue.put(req)
+
+    async def round_robin_scheduler(self, input_requests):
+        available_workers = list(self.workers.keys())
+        for r in input_requests:
+            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                available_workers
+            )
+        return
+
+    async def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            worker = min(
+                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+            )
+            self.put_req_to_worker(worker, r)
+        return
+
+    async def remove_dead_workers(self):
+        for i in list(self.workers.keys()):
+            worker_thread = self.workers[i]
+            if not worker_thread.liveness:
+                worker_thread.join()
+                # move unsuccessful requests back to the queue
+                while not worker_thread.request_queue.empty():
+                    self.recv_reqs.append(worker_thread.request_queue.get())
+                del self.workers[i]
+                logger.info(f"Stale worker {i} removed")
+
+    async def loop_for_forward(self):
+        while True:
+            await self.remove_dead_workers()
+
+            if self.have_any_live_worker():
+                next_step_input = list(self.recv_reqs)
+                self.recv_reqs = []
+                if next_step_input:
+                    await self.dispatching(next_step_input)
+            # else:
+            #     logger.error("There is no live worker.")
+
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker_thread in self.workers.values():
+                    worker_thread.request_queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                self.recv_reqs.append(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(self.recv_reqs):
+                    if req.rid == recv_req.rid:
+                        self.recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in list(self.workers.keys()):
+                        self.put_req_to_worker(worker, recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args=None,
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = Controller(
+            server_args.load_balance_method, server_args, port_args, model_overide_args
+        )
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+    loop = asyncio.get_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())