sglang 0.3.4.post1-py3-none-any.whl → 0.3.5-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only. A short script for reproducing this kind of diff locally follows the file list below.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
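What follows is the rendered diff for sglang/srt/managers/tokenizer_manager.py (+236 −334). A diff like this can also be rebuilt locally with only the Python standard library; in the sketch below the wheel file names, and their presence in the working directory, are assumptions (fetch them first, for example with pip download), not something this page provides.

import difflib
import zipfile

# Assumed local file names; a wheel is just a zip archive.
OLD = "sglang-0.3.4.post1-py3-none-any.whl"
NEW = "sglang-0.3.5-py3-none-any.whl"
MEMBER = "sglang/srt/managers/tokenizer_manager.py"


def read_member(wheel_path: str, member: str) -> list[str]:
    # Read one source file out of the wheel and split it into lines for difflib.
    with zipfile.ZipFile(wheel_path) as zf:
        text = zf.read(member).decode("utf-8", errors="replace")
    return text.splitlines(keepends=True)


print(
    "".join(
        difflib.unified_diff(
            read_member(OLD, MEMBER),
            read_member(NEW, MEMBER),
            fromfile=f"{OLD}/{MEMBER}",
            tofile=f"{NEW}/{MEMBER}",
        )
    )
)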
sglang/srt/managers/tokenizer_manager.py

@@ -16,10 +16,12 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
+import copy
 import dataclasses
-import json
 import logging
 import os
+import signal
+import sys
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -28,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
-from sglang.srt.
-
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -46,17 +44,17 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import get_zmq_socket, kill_child_process
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -80,30 +78,32 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer =
-
-
-        self.send_to_scheduler =
-
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
 
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.
-
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-
-
-
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+
+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len
+
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
 
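The constructor above now builds its sockets through get_zmq_socket(context, socket_type, ipc_name). The helper itself is not shown in this diff, so the following is only a generic zmq.asyncio PUSH/PULL sketch of that kind of setup; the endpoint name, the bind/connect split, and the single shared endpoint (used here only to make the example self-contained) are assumptions, not sglang's actual topology.

import asyncio

import zmq
import zmq.asyncio


async def main() -> None:
    context = zmq.asyncio.Context(2)

    pull = context.socket(zmq.PULL)             # like recv_from_detokenizer
    pull.bind("ipc:///tmp/example_tokenizer")   # ipc:// transport, Unix only

    push = context.socket(zmq.PUSH)             # like send_to_scheduler
    push.connect("ipc:///tmp/example_tokenizer")

    # Python objects are pickled on the wire, as with send_pyobj/recv_pyobj above.
    await push.send_pyobj({"rid": "demo", "text": "hello"})
    print(await pull.recv_pyobj())

    context.destroy()


asyncio.run(main())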
@@ -111,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -122,7 +122,7 @@ class TokenizerManager:
 
                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -139,9 +139,12 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # Others
+        self.gracefully_exit = False
+
     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
@@ -152,133 +155,58 @@ class TokenizerManager:
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default.
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )
 
-        obj.
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single
         if is_single:
-
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-
-
-
-
-
-
-
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
-
-                sampling_params = self._get_sampling_params(obj.sampling_params)
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-
-        else:  # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids
 
-
-            sampling_params.max_new_tokens = 0
+        if self.is_generation:
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data
+                obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
-                return_logprob = obj.return_logprob
-                logprob_start_len = obj.logprob_start_len
-                top_logprobs_num = obj.top_logprobs_num
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
 
-
-
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
@@ -287,230 +215,125 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )
 
-
-        return rid, input_ids
+        return tokenized_obj
 
-    async def
+    async def _wait_one_response(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
     ):
-
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
+        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
-
-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-        else:
-            parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
+        self.rid_to_state[obj.rid] = state
 
-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
-
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in done:
-                cur_index = tasks.index(task)
-
-                try:
-                    result = task.result()
-
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-        sampling_params.verify()
-        return sampling_params
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-
-
-                    raise ValueError(f"Abort request {rid}")
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-            if
+            if isinstance(obj, GenerateReqInput):
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob
-
-                        obj.top_logprobs_num
-                        if index is None
-                        else obj.top_logprobs_num[index]
-                    ),
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
                     obj.return_text_in_logprobs,
                 )
-            else:  # isinstance(obj, (EmbeddingReqInput,
+            else:  # isinstance(obj, (EmbeddingReqInput,))
                 out = state.out_list[-1]
 
-            out["index"] = response_index
-
-            # Log requests
-            if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj}, out={out}")
-
             state.out_list = []
             if state.finished:
-
+                if self.server_args.log_requests:
+                    # Log requests
+                    logger.info(f"in={obj}, out={out}")
+                del self.rid_to_state[obj.rid]
                 yield out
                 break
 
             state.event.clear()
             yield out
 
-    async def
+    async def _handle_batch_request(
         self,
-
-        obj: GenerateReqInput,
-        rid: str,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
-
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
+        batch_size = obj.batch_size
 
-
-
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass
 
     def flush_cache(self):
         req = FlushCacheReq()
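The streaming branch of the new _handle_batch_request keeps one pending __anext__() task per response generator in task_map, yields whichever result completes first, and re-arms that generator. A self-contained sketch of the same fan-in pattern, using made-up generators rather than sglang objects:

import asyncio


async def numbered(prefix: str, n: int):
    # Stand-in for one per-request response stream.
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{prefix}-{i}"


async def merge(generators):
    # One pending __anext__() task per live generator; yield in completion order.
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                yield task.result()
                # Re-arm the generator that just produced a value.
                task_map[asyncio.create_task(gen.__anext__())] = gen
            except StopAsyncIteration:
                pass  # This generator is exhausted; drop it.


async def main():
    async for item in merge([numbered("a", 2), numbered("b", 3)]):
        print(item)


asyncio.run(main())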
@@ -531,6 +354,25 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
+    async def get_memory_pool_size(self):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        req = GetMemPoolSizeReq()
+
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
+
+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
+        if self.server_args.dp_size == 1:
+            res = await self.mem_pool_size
+            return res.size
+        else:  # self.server_args.dp_size > 1
+            self.mem_pool_size_tmp = []
+            res = await self.mem_pool_size
+            ret = [r.size for r in res]
+            return ret
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -542,25 +384,41 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
 
         if not self.model_update_lock.locked():
+
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
                 while len(self.rid_to_state) > 0:
                     await asyncio.sleep(0.001)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
-
-                if
-
-
-
-
+
+                if self.server_args.dp_size == 1:
+                    result = await self.model_update_result
+                    if result.success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    return result.success, result.message
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp = []
+                    result = await self.model_update_result
+
+                    all_success = all([r.success for r in result])
+                    if all_success is True:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    all_message = [r.message for r in result]
+                    all_message = " | ".join(all_message)
+                    return all_success, all_message
+
         else:
             return False, "Another update is in progress. Please try again later."
 
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(
+            await asyncio.sleep(1)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -579,6 +437,28 @@ class TokenizerManager:
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
+        signal_handler = SignalHandler(self)
+        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        loop.create_task(self.sigterm_watchdog())
+
+    async def sigterm_watchdog(self):
+        while not self.gracefully_exit:
+            await asyncio.sleep(60)
+
+        # drain requests
+        while True:
+            remain_num_req = len(self.rid_to_state)
+            logger.info(
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+            )
+            if remain_num_req > 0:
+                await asyncio.sleep(5)
+            else:
+                break
+
+        kill_child_process(include_self=True)
+        sys.exit(-1)
+
     async def handle_loop(self):
         """The event loop that handles requests"""
 
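The SIGTERM path added above is split in two: the signal handler (see the SignalHandler class at the end of this diff) only flips gracefully_exit, while sigterm_watchdog polls that flag, drains rid_to_state, and then exits. A self-contained sketch of the same flag-and-drain pattern, with hypothetical names, much shorter sleeps, and the signal delivery simulated by calling the handler directly:

import asyncio
import signal


class GracefulExit:
    """Toy version of the SIGTERM flag plus drain loop (names hypothetical)."""

    def __init__(self):
        self.gracefully_exit = False
        self.in_flight = set()  # stand-in for rid_to_state

    def on_sigterm(self, signum=None, frame=None):
        # Only flip a flag in the handler; never block here.
        self.gracefully_exit = True

    async def watchdog(self):
        # Poll the flag, then drain outstanding requests before shutting down.
        while not self.gracefully_exit:
            await asyncio.sleep(0.05)
        while self.in_flight:
            print(f"draining, {len(self.in_flight)} request(s) left")
            await asyncio.sleep(0.05)
        print("drained, shutting down")


async def main():
    g = GracefulExit()
    g.in_flight.add("req-1")
    loop = asyncio.get_running_loop()
    loop.add_signal_handler(signal.SIGTERM, g.on_sigterm)  # Unix event loops only
    watchdog = asyncio.create_task(g.watchdog())

    await asyncio.sleep(0.1)
    g.on_sigterm()                 # simulate SIGTERM being delivered
    await asyncio.sleep(0.1)
    g.in_flight.discard("req-1")   # simulate the last request finishing
    await watchdog


asyncio.run(main())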
@@ -588,7 +468,22 @@ class TokenizerManager:
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, UpdateWeightReqOutput):
-                self.
+                if self.server_args.dp_size == 1:
+                    self.model_update_result.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp.append(recv_obj)
+                    # set future if the all results are recevied
+                    if len(self.model_update_tmp) == self.server_args.dp_size:
+                        self.model_update_result.set_result(self.model_update_tmp)
+                continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.mem_pool_size.set_result(recv_obj)
+                else:  # self.sever_args.dp_size > 1
+                    self.mem_pool_size_tmp.append(recv_obj)
+                    # set future if the all results are received
+                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
+                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue
 
             assert isinstance(
@@ -607,14 +502,10 @@ class TokenizerManager:
                         "meta_info": recv_obj.meta_info[i],
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
-                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                     out_dict = {
-                        "token_ids": recv_obj.
-                            read_start : recv_obj.read_offsets[i]
-                        ],
+                        "token_ids": recv_obj.output_ids[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
@@ -666,7 +557,7 @@ class TokenizerManager:
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
             (logprob, token_id, token_text)
-            for (logprob, token_id), token_text
+            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
         ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
@@ -678,3 +569,14 @@ class TokenizerManager:
                     token_top_logprobs, decode_to_text
                 )
         return top_logprobs
+
+
+class SignalHandler:
+    def __init__(self, tokenizer_manager):
+        self.tokenizer_manager = tokenizer_manager
+
+    def signal_handler(self, signum=None, frame=None):
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
+        )
+        self.tokenizer_manager.gracefully_exit = True