sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +51 -13
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +70 -78
- sglang/srt/managers/schedule_batch.py +33 -49
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +137 -80
- sglang/srt/managers/tokenizer_manager.py +224 -336
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/model_runner.py +8 -17
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/server.py +31 -35
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/runners.py +2 -1
- sglang/test/test_utils.py +73 -25
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -16,10 +16,12 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
+import copy
 import dataclasses
-import json
 import logging
 import os
+import signal
+import sys
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -28,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
-from sglang.srt.hf_transformers_utils import (
-
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -49,16 +47,14 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import get_zmq_socket, kill_child_process
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -82,30 +78,32 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer =
-
-
-        self.send_to_scheduler =
-
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
 
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.
-
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-
-
-
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+
+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len
+
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
 
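Note: the constructor above now creates both IPC channels through `get_zmq_socket` from `sglang.srt.utils` and reads model metadata from a single `ModelConfig` object. The helper itself is not part of this diff; the sketch below is only an illustration of what a function with the call shape used above (context, socket type, IPC endpoint name) might do, not the actual sglang implementation.

```python
# Illustrative sketch only -- an approximation of a helper with the signature
# used above (context, socket_type, endpoint); not sglang.srt.utils.get_zmq_socket.
import zmq
import zmq.asyncio


def get_zmq_socket_sketch(
    context: zmq.asyncio.Context, socket_type: int, endpoint: str
) -> zmq.asyncio.Socket:
    """Create a PULL/PUSH socket and attach it to an ipc:// endpoint."""
    socket = context.socket(socket_type)
    if socket_type == zmq.PULL:
        # The receiving side owns the endpoint, so it binds.
        socket.bind(f"ipc://{endpoint}")
    else:
        # The sending side connects to an endpoint bound elsewhere.
        socket.connect(f"ipc://{endpoint}")
    return socket
```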
@@ -113,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -124,7 +122,7 @@ class TokenizerManager:
 
                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -141,9 +139,12 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # Others
+        self.gracefully_exit = False
+
     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
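With the new signature, `generate_request` takes a single `GenerateReqInput` or `EmbeddingReqInput` and yields outputs as an async generator. A hypothetical caller-side sketch (not from the wheel) is shown below; the constructor arguments passed to `GenerateReqInput` are assumptions inferred from the hunks that follow.

```python
# Hypothetical usage sketch: drive the async generator returned by
# TokenizerManager.generate_request for one streaming request.
from sglang.srt.managers.io_struct import GenerateReqInput


async def stream_one(tokenizer_manager, prompt: str):
    obj = GenerateReqInput(
        text=prompt,
        sampling_params={"max_new_tokens": 32},  # assumed field values
        stream=True,
    )
    # generate_request tokenizes the request, forwards it to the scheduler,
    # and then yields incremental outputs until the request finishes.
    async for chunk in tokenizer_manager.generate_request(obj):
        print(chunk["meta_info"]["id"], chunk.get("text", ""))
```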
@@ -154,133 +155,58 @@ class TokenizerManager:
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default.
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )
 
-        obj.
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single
         if is_single:
-
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-
-
-
-
-
-
-
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
-
-                sampling_params = self._get_sampling_params(obj.sampling_params)
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-
-        else:  # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids
 
-
-            sampling_params.max_new_tokens = 0
+        if self.is_generation:
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data
+                obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
-                return_logprob = obj.return_logprob
-                logprob_start_len = obj.logprob_start_len
-                top_logprobs_num = obj.top_logprobs_num
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
 
-
-
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
@@ -289,230 +215,125 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )
 
-
-        return rid, input_ids
+        return tokenized_obj
 
-    async def
+    async def _wait_one_response(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
     ):
-
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
+        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
-
-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-            else:
-                parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
-
-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
+        self.rid_to_state[obj.rid] = state
 
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in done:
-                cur_index = tasks.index(task)
-
-                try:
-                    result = task.result()
-
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
-        return sampling_params
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-
-
-                    raise ValueError(f"Abort request {rid}")
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-            if
+            if isinstance(obj, GenerateReqInput):
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob
-
-                        obj.top_logprobs_num
-                        if index is None
-                        else obj.top_logprobs_num[index]
-                    ),
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
                     obj.return_text_in_logprobs,
                 )
-            else:  # isinstance(obj, (EmbeddingReqInput,
+            else:  # isinstance(obj, (EmbeddingReqInput,))
                 out = state.out_list[-1]
 
-            out["index"] = response_index
-
-            # Log requests
-            if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj}, out={out}")
-
             state.out_list = []
             if state.finished:
-
+                if self.server_args.log_requests:
+                    # Log requests
+                    logger.info(f"in={obj}, out={out}")
+                del self.rid_to_state[obj.rid]
                 yield out
                 break
 
            state.event.clear()
            yield out
 
-    async def
+    async def _handle_batch_request(
         self,
-
-        obj: GenerateReqInput,
-        rid: str,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
-
-        try:
-            await asyncio.wait_for(state.event.wait(), timeout=4)
-            break
-        except asyncio.TimeoutError:
-            if request is not None and await request.is_disconnected():
-                for rid in obj.rid:
-                    self.abort_request(rid)
-                raise ValueError(f"Abort request {rid}")
-            continue
+        batch_size = obj.batch_size
 
-
-
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass
 
     def flush_cache(self):
         req = FlushCacheReq()
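The streaming branch of `_handle_batch_request` above multiplexes several per-request async generators with `asyncio.wait`, re-arming whichever generator just produced a value until all are exhausted. The self-contained sketch below reproduces only that pattern with dummy generators; it is illustrative and not part of the package.

```python
# Fan-out/multiplex sketch: drive several async generators concurrently and
# forward whichever value arrives first, until every generator is exhausted.
import asyncio


async def numbers(start: int, n: int, delay: float):
    for i in range(n):
        await asyncio.sleep(delay)
        yield start + i


async def multiplex(generators):
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                yield task.result()
                # Re-arm the generator that just produced a value.
                task_map[asyncio.create_task(gen.__anext__())] = gen
            except StopAsyncIteration:
                pass


async def main():
    gens = [numbers(0, 3, 0.03), numbers(100, 3, 0.01)]
    async for value in multiplex(gens):
        print(value)


asyncio.run(main())
```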
@@ -538,9 +359,19 @@ class TokenizerManager:
         self.create_handle_loop()
 
         req = GetMemPoolSizeReq()
+
         self.send_to_scheduler.send_pyobj(req)
         self.mem_pool_size = asyncio.Future()
-
+
+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
+        if self.server_args.dp_size == 1:
+            res = await self.mem_pool_size
+            return res.size
+        else:  # self.server_args.dp_size > 1
+            self.mem_pool_size_tmp = []
+            res = await self.mem_pool_size
+            ret = [r.size for r in res]
+            return ret
 
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -553,25 +384,41 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
 
         if not self.model_update_lock.locked():
+
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
                 while len(self.rid_to_state) > 0:
                     await asyncio.sleep(0.001)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
-
-                if
-
-
-
-
+
+                if self.server_args.dp_size == 1:
+                    result = await self.model_update_result
+                    if result.success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    return result.success, result.message
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp = []
+                    result = await self.model_update_result
+
+                    all_success = all([r.success for r in result])
+                    if all_success is True:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    all_message = [r.message for r in result]
+                    all_message = " | ".join(all_message)
+                    return all_success, all_message
+
         else:
             return False, "Another update is in progress. Please try again later."
 
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(
+            await asyncio.sleep(1)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -590,6 +437,28 @@ class TokenizerManager:
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
+        signal_handler = SignalHandler(self)
+        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        loop.create_task(self.sigterm_watchdog())
+
+    async def sigterm_watchdog(self):
+        while not self.gracefully_exit:
+            await asyncio.sleep(60)
+
+        # drain requests
+        while True:
+            remain_num_req = len(self.rid_to_state)
+            logger.info(
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+            )
+            if remain_num_req > 0:
+                await asyncio.sleep(5)
+            else:
+                break
+
+        kill_child_process(include_self=True)
+        sys.exit(-1)
+
     async def handle_loop(self):
         """The event loop that handles requests"""
 
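The watchdog above exits only after `rid_to_state` drains. The self-contained sketch below shows the same graceful-shutdown shape (a SIGTERM handler flips a flag, a watchdog waits for in-flight work to drain, then the process stops); it simulates the signal with `os.kill`, relies on the Unix-only `loop.add_signal_handler`, and is illustrative only.

```python
import asyncio
import os
import signal


class MiniManager:
    def __init__(self):
        self.gracefully_exit = False
        self.rid_to_state = {"req-1": object()}  # one in-flight request

    def on_sigterm(self):
        self.gracefully_exit = True

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(0.05)
        while self.rid_to_state:  # drain in-flight requests
            await asyncio.sleep(0.05)
        print("drained, shutting down")


async def main():
    manager = MiniManager()
    loop = asyncio.get_running_loop()
    loop.add_signal_handler(signal.SIGTERM, manager.on_sigterm)  # Unix only

    watchdog = asyncio.create_task(manager.sigterm_watchdog())
    os.kill(os.getpid(), signal.SIGTERM)  # simulate `kill <pid>`
    await asyncio.sleep(0.1)
    manager.rid_to_state.clear()  # the last request finishes
    await watchdog


asyncio.run(main())
```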
@@ -599,10 +468,22 @@ class TokenizerManager:
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, UpdateWeightReqOutput):
-                self.
+                if self.server_args.dp_size == 1:
+                    self.model_update_result.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp.append(recv_obj)
+                    # set future if the all results are recevied
+                    if len(self.model_update_tmp) == self.server_args.dp_size:
+                        self.model_update_result.set_result(self.model_update_tmp)
                 continue
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                self.
+                if self.server_args.dp_size == 1:
+                    self.mem_pool_size.set_result(recv_obj)
+                else:  # self.sever_args.dp_size > 1
+                    self.mem_pool_size_tmp.append(recv_obj)
+                    # set future if the all results are received
+                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
+                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue
 
             assert isinstance(
@@ -621,14 +502,10 @@ class TokenizerManager:
                     "meta_info": recv_obj.meta_info[i],
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
-                read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                 out_dict = {
-                    "token_ids": recv_obj.
-                        read_start : recv_obj.read_offsets[i]
-                    ],
+                    "token_ids": recv_obj.output_ids[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
-
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -680,7 +557,7 @@ class TokenizerManager:
             token_texts = self.tokenizer.batch_decode(token_ids)
             return [
                 (logprob, token_id, token_text)
-                for (logprob, token_id), token_text
+                for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
             ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
@@ -692,3 +569,14 @@ class TokenizerManager:
                 token_top_logprobs, decode_to_text
             )
         return top_logprobs
+
+
+class SignalHandler:
+    def __init__(self, tokenizer_manager):
+        self.tokenizer_manager = tokenizer_manager
+
+    def signal_handler(self, signum=None, frame=None):
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
+        )
+        self.tokenizer_manager.gracefully_exit = True