sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,66 +1,32 @@
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type
 
 import numpy as np
 import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.
-from vllm.model_executor.
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
 
 
-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}
-
 logger = logging.getLogger("model_runner")
 
 # for server args in model endpoints
 global_server_args_dict = {}
 
 
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
@@ -253,113 +219,102 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args
 
         global global_server_args_dict
-        global_server_args_dict =
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
 
         # Init torch distributed
+        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")
+
+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
 
-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
 
         self.is_multimodal_model = is_multimodal_model(self.model_config)
 
     def load_model(self):
-        "
-        [... 29 deleted lines not shown in this diff view ...]
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-                model.load_weights(
-                    self.model_config.path,
-                    cache_dir=None,
-                    load_format=self.load_format,
-                    revision=None,
-                )
-        self.model = model.eval()
-
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )
+        logger.info(f"[rank={self.tp_rank}] Load weight end. "
+                    f"Type={type(self.model).__name__}. "
+                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
 
     def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
        head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
     def init_memory_pool(self, total_gpu_memory):
-        self.
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
 
-        if self.
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enought memory. " "Please try to increase --mem-fraction-static."
             )
 
         self.req_to_token_pool = ReqToTokenPool(
-            int(self.
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.
+            self.max_total_num_tokens,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
             head_dim=self.model_config.head_dim,
@@ -456,3 +411,35 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(entry, list):  # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
sglang/srt/managers/router/radix_cache.py

@@ -58,7 +58,7 @@ class RadixCache:
 
     def insert(self, key, value=None):
         if self.disable:
-            return
+            return 0
 
         if value is None:
             value = [x for x in key]
@@ -76,6 +76,12 @@ class RadixCache:
         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
         new_prefix_len = self.insert(token_ids, indices.clone())
 
+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.dec_refs(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
         # Radix Cache takes one ref in memory pool
         self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
 
sglang/srt/managers/router/scheduler.py

@@ -6,15 +6,15 @@ class Scheduler:
     def __init__(
         self,
         schedule_heuristic,
-
-
-
+        max_running_seqs,
+        max_prefill_num_tokens,
+        max_total_num_tokens,
         tree_cache,
     ):
         self.schedule_heuristic = schedule_heuristic
-        self.
-        self.
-        self.
+        self.max_running_seqs = max_running_seqs
+        self.max_prefill_num_tokens = max_prefill_num_tokens
+        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
     def get_priority_queue(self, forward_queue):
sglang/srt/managers/tokenizer_manager.py

@@ -4,13 +4,14 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import List, Dict
 
 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks
 
 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -19,16 +20,18 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -42,51 +45,6 @@ class ReqState:
     event: asyncio.Event
 
 
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -132,7 +90,7 @@ class TokenizerManager:
         )
 
         self.to_create_loop = True
-        self.rid_to_state
+        self.rid_to_state: Dict[str, ReqState] = {}
 
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -153,10 +111,11 @@ class TokenizerManager:
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )
 
-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-
+            self.create_handle_loop()
 
+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -169,7 +128,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )
 
             sampling_params = SamplingParams(**obj.sampling_params)
@@ -207,23 +166,38 @@ class TokenizerManager:
             self.rid_to_state[rid] = state
 
             while True:
-
-
-
-
-
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
 
                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")
 
-                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
 
             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -273,45 +247,84 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
-                    self.convert_logprob_style(
-
-
-
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]
 
             yield output_list
 
-
-
-        self.send_to_router.send_pyobj(
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
+
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
 
-
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-            [... 16 deleted lines not shown in this diff view ...]
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_str[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -320,11 +333,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] =
-
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
                 )
-                ret["meta_info"]["decode_top_logprobs"] =
-
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
         return ret
 
@@ -344,3 +361,49 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())