sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py
@@ -1,66 +1,32 @@
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

 import numpy as np
 import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model


-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}
-
 logger = logging.getLogger("model_runner")

 # for server args in model endpoints
 global_server_args_dict = {}


-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
@@ -253,113 +219,102 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")
+
+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")

-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
         self.load_model()
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-        # Load weights
-        quant_config = None
-
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-        self.model = model.eval()
-
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )
+        logger.info(f"[rank={self.tp_rank}] Load weight end. "
+                    f"Type={type(self.model).__name__}. "
+                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")

     def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
         head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
-        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

-        if self.max_total_num_token <= 0:
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enought memory. " "Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_token / self.model_config.context_len * 256),
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_token,
+            self.max_total_num_tokens,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
             head_dim=self.model_config.head_dim,
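Note: ModelRunner now receives the whole ServerArgs object in place of the loose load_format / trust_remote_code / server_args_dict arguments, so load_format, quantization, and trust_remote_code are read from it at load time. A minimal sketch of the new call shape; the model path, memory fraction, and port are placeholders, and the ModelConfig construction is simplified, none of it taken from this diff:

    # Sketch only: illustrates the 0.1.17 constructor signature with placeholder values.
    from sglang.srt.managers.router.model_runner import ModelRunner
    from sglang.srt.model_config import ModelConfig
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf")  # placeholder model
    model_config = ModelConfig(server_args.model_path)

    runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,   # fraction of GPU memory for weights + KV cache
        tp_rank=0,
        tp_size=1,
        nccl_port=28888,           # placeholder port
        server_args=server_args,   # replaces load_format / trust_remote_code / server_args_dict
    )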
@@ -456,3 +411,35 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(entry, list):  # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
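The two helpers above replace the old get_model_cls_by_arch_name: model classes are still discovered through the EntryClass attribute of each module in sglang.srt.models, but the lookup is now installed into vLLM's ModelRegistry, so the vllm get_model call in load_model resolves architectures to sglang implementations. A rough illustration of the resolution path; "LlamaForCausalLM" is an assumed example of a registered EntryClass name:

    # Sketch: resolving an architecture name through the patched registry.
    import sglang.srt.managers.router.model_runner  # noqa: F401  (importing applies the monkey patch)
    from vllm.model_executor.models import ModelRegistry

    # After the setattr above, this returns the sglang model class whose
    # EntryClass name matches, or raises ValueError listing supported ones.
    model_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")  # assumed example
    print(model_cls.__module__)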
sglang/srt/managers/router/radix_cache.py
@@ -58,7 +58,7 @@ class RadixCache:

     def insert(self, key, value=None):
         if self.disable:
-            return len(key)
+            return 0

         if value is None:
             value = [x for x in key]
@@ -76,6 +76,12 @@
         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
         new_prefix_len = self.insert(token_ids, indices.clone())

+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.dec_refs(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
         # Radix Cache takes one ref in memory pool
         self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])

sglang/srt/managers/router/scheduler.py
@@ -6,15 +6,15 @@ class Scheduler:
     def __init__(
         self,
         schedule_heuristic,
-        max_running_seq,
-        max_prefill_num_token,
-        max_total_num_token,
+        max_running_seqs,
+        max_prefill_num_tokens,
+        max_total_num_tokens,
         tree_cache,
     ):
         self.schedule_heuristic = schedule_heuristic
-        self.max_running_seq = max_running_seq
-        self.max_prefill_num_token = max_prefill_num_token
-        self.max_total_num_token = max_total_num_token
+        self.max_running_seqs = max_running_seqs
+        self.max_prefill_num_tokens = max_prefill_num_tokens
+        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

     def get_priority_queue(self, forward_queue):
sglang/srt/managers/tokenizer_manager.py
@@ -4,13 +4,14 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import List, Dict

 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks

 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -19,16 +20,18 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -42,51 +45,6 @@ class ReqState:
     event: asyncio.Event


-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -132,7 +90,7 @@ class TokenizerManager:
         )

         self.to_create_loop = True
-        self.rid_to_state = {}  # Dict[str -> ReqState]
+        self.rid_to_state: Dict[str, ReqState] = {}

     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -153,10 +111,11 @@
             image_data, aspect_ratio, grid_pinpoints, self.processor
         )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -169,7 +128,7 @@
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)
@@ -207,23 +166,38 @@
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
-                out = self.convert_logprob_style(state.out_list[-1],
-                                                 obj.return_logprob,
-                                                 obj.top_logprobs_num,
-                                                 obj.return_text_in_logprobs)
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )

                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")

-                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)
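generate_request now takes the incoming HTTP request and, instead of blocking forever on event.wait(), polls with a 4-second timeout so it can notice a disconnected client and abort the scheduled request. The same pattern in isolation; request.is_disconnected follows the Starlette/FastAPI Request API, and the helper name here is illustrative:

    import asyncio

    async def wait_or_abort(event: asyncio.Event, request, abort_fn, rid):
        """Illustrative helper: wait for a result, aborting on client disconnect."""
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
                return  # the result for `rid` is ready
            except asyncio.TimeoutError:
                # No result yet; bail out only if the caller went away.
                if request is not None and await request.is_disconnected():
                    abort_fn(rid)
                    raise ValueError(f"Abort request {rid}")
                # Otherwise keep waiting.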
@@ -273,45 +247,84 @@
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
-                    self.convert_logprob_style(state.out_list[-1],
-                                               obj.return_logprob[i],
-                                               obj.top_logprobs_num[i],
-                                               obj.return_text_in_logprobs))
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]

             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
+
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks

-    async def create_handle_loop(self):
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, BatchStrOut):
-                for i, rid in enumerate(recv_obj.rids):
-                    recv_obj.meta_info[i]["id"] = rid
-                    out_dict = {
-                        "text": recv_obj.output_str[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                    state = self.rid_to_state[rid]
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished[i]
-                    state.event.set()
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
-
-    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_str[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -320,11 +333,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
                 )
-                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
         return ret

@@ -344,3 +361,49 @@
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())