sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py

@@ -17,21 +17,56 @@ limitations under the License.
  # and "Punica: Multi-Tenant LoRA Serving"


+ import logging
  import re
- from dataclasses import dataclass

  import torch

  from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
  from sglang.srt.lora.lora_config import LoRAConfig
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import is_hip, replace_submodule
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.utils import is_flashinfer_available, replace_submodule

- # ROCm: flashinfer available later
- if not is_hip():
+ logger = logging.getLogger(__name__)
+
+ if is_flashinfer_available():
      from flashinfer import SegmentGEMMWrapper


+ def get_module_name(name):
+     # Fallback solution of mapping from config module name to module name in model class.
+     # Please check if it aligns with your base model.
+     # Please implement the function in the model class if it is not.
+     # You can reference this function in llama.py.
+     params_mapping = {
+         "q_proj": "qkv_proj",
+         "k_proj": "qkv_proj",
+         "v_proj": "qkv_proj",
+         "gate_proj": "gate_up_proj",
+         "up_proj": "gate_up_proj",
+     }
+     return params_mapping.get(name, name)
+
+
+ def get_hidden_dim(module_name, config):
+     # Fallback solution of get_hidden_dim for different modules
+     # Please check if it aligns with your base model.
+     # Please implement the function in the model class if it is not.
+     # You can reference this function in llama.py.
+     if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+         return config.hidden_size, config.hidden_size
+     elif module_name in ["kv_proj"]:
+         return config.hidden_size, config.hidden_size // (
+             config.num_attention_heads // config.num_key_value_heads
+         )
+     elif module_name == "gate_up_proj":
+         return config.hidden_size, config.intermediate_size
+     elif module_name == "down_proj":
+         return config.intermediate_size, config.hidden_size
+     else:
+         raise NotImplementedError()
+
+
  def get_stacked_name(name):
      # origin name -> (name for A, name for B)
      params_mapping = {
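The comments on get_module_name and get_hidden_dim above describe them as fallbacks: LoRAManager first checks, via hasattr (see the next hunks), whether the base model itself provides these hooks, and only otherwise uses the module-level versions. Below is a minimal, hypothetical sketch of such hooks on a model class, assuming a Llama-style config with hidden_size, intermediate_size, num_attention_heads, and num_key_value_heads; the reference implementation the comments point to lives in llama.py, and MyLlamaForCausalLM plus the toy config are illustrative only.

    from types import SimpleNamespace

    class MyLlamaForCausalLM:
        """Hypothetical model class exposing the hooks LoRAManager probes via hasattr()."""

        def __init__(self, config):
            self.config = config

        def get_module_name(self, name):
            # Map config module names (q_proj, k_proj, ...) onto the fused modules
            # that actually exist in the model implementation.
            return {
                "q_proj": "qkv_proj",
                "k_proj": "qkv_proj",
                "v_proj": "qkv_proj",
                "gate_proj": "gate_up_proj",
                "up_proj": "gate_up_proj",
            }.get(name, name)

        def get_hidden_dim(self, module_name):
            # Return (input_dim, output_dim) so LoRAManager can size its A/B buffers.
            cfg = self.config
            if module_name in ("q_proj", "o_proj", "qkv_proj"):
                return cfg.hidden_size, cfg.hidden_size
            if module_name == "kv_proj":
                return cfg.hidden_size, cfg.hidden_size // (
                    cfg.num_attention_heads // cfg.num_key_value_heads
                )
            if module_name == "gate_up_proj":
                return cfg.hidden_size, cfg.intermediate_size
            if module_name == "down_proj":
                return cfg.intermediate_size, cfg.hidden_size
            raise NotImplementedError(module_name)

    # Toy config with Llama-like fields, only for illustration.
    cfg = SimpleNamespace(
        hidden_size=4096, intermediate_size=11008,
        num_attention_heads=32, num_key_value_heads=8,
    )
    model = MyLlamaForCausalLM(cfg)
    print(model.get_module_name("q_proj"))  # qkv_proj
    print(model.get_hidden_dim("kv_proj"))  # (4096, 1024)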
@@ -104,12 +139,20 @@ class LoRAManager:
          self.origin_target_modules = set(self.origin_target_modules) | set(
              self.configs[name].target_modules
          )
-         self.target_modules = set(
-             [
+         if hasattr(self.base_model, "get_module_name"):
+             self.target_modules = {
                  self.base_model.get_module_name(module)
                  for module in self.origin_target_modules
-             ]
-         )
+             }
+         else:
+             logger.warning(
+                 f"WARNING: get_module_name() is not defined, "
+                 f"which is used to map config module name to model implementation module name."
+                 f"Use the default one, but please check if it is correct for your model."
+             )
+             self.target_modules = {
+                 get_module_name(module) for module in self.origin_target_modules
+             }
          self.target_weights = set(
              [get_stacked_name(module) for module in self.origin_target_modules]
          )
@@ -147,7 +190,15 @@ class LoRAManager:
          num_layer = self.base_hf_config.num_hidden_layers
          for module_A, module_B in self.target_weights:
              # init A tensor, column_major=True
-             hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
+             if hasattr(self.base_model, "get_hidden_dim"):
+                 hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
+             else:
+                 logger.warning(
+                     f"WARNING: get_hidden_dim() is not defined, "
+                     f"which is used to get the hidden dim for different lora modules"
+                     f"Use the default one, but please check if it is correct for your model."
+                 )
+                 hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
              c = self.loras[-1].get_stacked_multiply(module_A)
              if module_A not in self.A_buffer:
                  self.A_buffer[module_A] = [
@@ -163,7 +214,15 @@ class LoRAManager:
                      for i in range(num_layer)
                  ]
              # init B tensor, column_major=True
-             _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
+             if hasattr(self.base_model, "get_hidden_dim"):
+                 _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
+             else:
+                 logger.warning(
+                     f"WARNING: get_hidden_dim() is not defined, "
+                     f"which is used to get the hidden dim for different lora modules"
+                     f"Use the default one, but please check if it is correct for your model."
+                 )
+                 _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
              c = self.loras[-1].get_stacked_multiply(module_B)
              if module_B not in self.B_buffer:
                  self.B_buffer[module_B] = [
@@ -208,33 +267,46 @@ class LoRAManager:
                  if lora_weight_name:
                      self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-     def prepare_lora_batch(self, batch, extend_seq_lens=None):
+     def prepare_lora_batch(self, forward_batch: ForwardBatch):
          # load active loras into lora memory pool
-         cur_uids = set([req.lora_path for req in batch.reqs])
+         cur_uids = set(forward_batch.lora_paths)
          assert len(cur_uids) <= self.max_loras_per_batch
          i = 0
+         j = len(self.active_uids)
          evictable_uids = list(self.active_uids)
          for uid in cur_uids:
              if uid not in self.active_uids:
-                 while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                     i += 1
-                 if i < len(evictable_uids):
+                 if j < self.max_loras_per_batch:
+                     index = j
+                     j += 1
+                 else:
+                     while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                         i += 1
+                     assert i < len(evictable_uids)
                      self.active_uids.remove(evictable_uids[i])
                      self.buffer_id.pop(evictable_uids[i])
-                 self.load_lora(uid, i)
+                     index = i
+                     i += 1
+                 self.load_lora(uid, index)
                  self.active_uids.add(uid)
-                 self.buffer_id[uid] = i
-                 i += 1
+                 self.buffer_id[uid] = index

          if cur_uids == set([None]):
              return

          # setup lora in forward modules
-         bs = len(batch.reqs)
-         seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+         bs = forward_batch.batch_size
+         seg_lens = (
+             forward_batch.extend_seq_lens
+             if forward_batch.forward_mode.is_extend()
+             else torch.ones(bs, device="cuda")
+         )
+         # FIXME: reuse the data rather than recompute
+         seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
          weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-         for i, req in enumerate(batch.reqs):
-             weight_indices[i] = self.buffer_id[req.lora_path]
+         for i, lora_path in enumerate(forward_batch.lora_paths):
+             weight_indices[i] = self.buffer_id[lora_path]

          for module_name, module in self.lora_modules:
              layer_id = get_layer_id(module_name)
@@ -245,7 +317,7 @@ class LoRAManager:
                      self.A_buffer[weight_name][layer_id],
                      self.B_buffer[weight_name][layer_id],
                      bs,
-                     seg_lens,
+                     seg_indptr,
                      weight_indices,
                  )
              else:
@@ -254,6 +326,6 @@ class LoRAManager:
                      self.B_buffer["q_proj"][layer_id],
                      self.B_buffer["kv_proj"][layer_id],
                      bs,
-                     seg_lens,
+                     seg_indptr,
                      weight_indices,
                  )
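The reworked prepare_lora_batch no longer hands raw segment lengths to set_lora_info; it builds a CSR-style seg_indptr with a prefix sum and passes that instead. A standalone sketch of that conversion, using small CPU tensors purely for illustration (the real code allocates on CUDA):

    import torch

    # Segment lengths for a batch of 3 requests (extend mode: new tokens per request).
    seg_lens = torch.tensor([3, 1, 4], dtype=torch.int32)
    bs = seg_lens.numel()

    # Exclusive prefix sum: seg_indptr[k] is where request k's tokens start,
    # seg_indptr[k + 1] is where they end.
    seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

    print(seg_indptr.tolist())  # [0, 3, 4, 8]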
sglang/srt/managers/detokenizer_manager.py

@@ -15,13 +15,12 @@ limitations under the License.

  """DetokenizerManager is a process that detokenizes the token ids."""

- import asyncio
  import dataclasses
+ import logging
+ from collections import OrderedDict
  from typing import List

- import uvloop
  import zmq
- import zmq.asyncio

  from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.io_struct import (
@@ -32,9 +31,10 @@ from sglang.srt.managers.io_struct import (
  )
  from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
  from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import configure_logger, kill_parent_process
  from sglang.utils import find_printable_text, get_exception_traceback

- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+ logger = logging.getLogger(__name__)


  @dataclasses.dataclass
@@ -57,12 +57,12 @@ class DetokenizerManager:
          port_args: PortArgs,
      ):
          # Init inter-process communication
-         context = zmq.asyncio.Context(2)
-         self.recv_from_router = context.socket(zmq.PULL)
-         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
+         context = zmq.Context(2)
+         self.recv_from_scheduler = context.socket(zmq.PULL)
+         self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

          self.send_to_tokenizer = context.socket(zmq.PUSH)
-         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
+         self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")

          if server_args.skip_tokenizer_init:
              self.tokenizer = None
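The detokenizer's sockets change from zmq.asyncio over tcp://127.0.0.1:<port> to plain blocking sockets over ipc:// endpoints. Below is a minimal sketch of the same PULL/PUSH pattern over IPC; the endpoint path is made up for illustration, and the ipc:// transport assumes a platform with Unix domain sockets.

    import zmq

    context = zmq.Context(2)

    # Consumer side binds a PULL socket on an ipc:// endpoint (as the detokenizer does).
    receiver = context.socket(zmq.PULL)
    receiver.bind("ipc:///tmp/example_detokenizer")

    # Producer side connects a PUSH socket to the same endpoint.
    sender = context.socket(zmq.PUSH)
    sender.connect("ipc:///tmp/example_detokenizer")

    sender.send_pyobj({"rid": "req-0", "output_ids": [1, 2, 3]})
    print(receiver.recv_pyobj())  # {'rid': 'req-0', 'output_ids': [1, 2, 3]}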
@@ -73,13 +73,13 @@ class DetokenizerManager:
                  trust_remote_code=server_args.trust_remote_code,
              )

-         self.decode_status = {}
+         self.decode_status = LimitedCapacityDict()

-     async def handle_loop(self):
+     def event_loop(self):
          """The event loop that handles requests"""

          while True:
-             recv_obj = await self.recv_from_router.recv_pyobj()
+             recv_obj = self.recv_from_scheduler.recv_pyobj()

              if isinstance(recv_obj, BatchEmbeddingOut):
                  # If it is embedding model, no detokenization is needed.
@@ -170,16 +170,29 @@ class DetokenizerManager:
              )


- def start_detokenizer_process(
+ class LimitedCapacityDict(OrderedDict):
+     def __init__(self, capacity=1 << 15, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.capacity = capacity
+
+     def __setitem__(self, key, value):
+         if len(self) >= self.capacity:
+             # Remove the oldest element (first item in the dict)
+             self.popitem(last=False)
+         # Set the new item
+         super().__setitem__(key, value)
+
+
+ def run_detokenizer_process(
      server_args: ServerArgs,
      port_args: PortArgs,
-     pipe_writer,
  ):
+     configure_logger(server_args)
+
      try:
          manager = DetokenizerManager(server_args, port_args)
+         manager.event_loop()
      except Exception:
-         pipe_writer.send(get_exception_traceback())
-         raise
-     pipe_writer.send("init ok")
-     loop = asyncio.get_event_loop()
-     loop.run_until_complete(manager.handle_loop())
+         msg = get_exception_traceback()
+         logger.error(msg)
+         kill_parent_process()
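LimitedCapacityDict, shown in full in the hunk above, replaces the unbounded decode_status dict: once the capacity (default 1 << 15 entries) is reached, each new insertion evicts the oldest entry in FIFO order. A small usage sketch:

    from collections import OrderedDict

    class LimitedCapacityDict(OrderedDict):
        def __init__(self, capacity=1 << 15, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.capacity = capacity

        def __setitem__(self, key, value):
            if len(self) >= self.capacity:
                # Remove the oldest element (first item in the dict)
                self.popitem(last=False)
            super().__setitem__(key, value)

    d = LimitedCapacityDict(capacity=2)
    d["a"] = 1
    d["b"] = 2
    d["c"] = 3      # capacity reached: "a" is evicted
    print(list(d))  # ['b', 'c']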
sglang/srt/managers/image_processor.py (new file)

@@ -0,0 +1,187 @@
+ # TODO: also move pad_input_ids into this module
+ import asyncio
+ import concurrent.futures
+ import logging
+ import multiprocessing as mp
+ import os
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import transformers
+
+ from sglang.srt.hf_transformers_utils import get_processor
+ from sglang.srt.mm_utils import expand2square, process_anyres_image
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import load_image
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)
+
+ global global_processor
+
+
+ def init_global_processor(server_args: ServerArgs):
+     """Init the global processor for multi modal models."""
+     global global_processor
+     transformers.logging.set_verbosity_error()
+     global_processor = get_processor(
+         server_args.tokenizer_path,
+         tokenizer_mode=server_args.tokenizer_mode,
+         trust_remote_code=server_args.trust_remote_code,
+     )
+
+
+ class BaseImageProcessor(ABC):
+     @abstractmethod
+     async def process_images_async(self, image_data, **kwargs):
+         pass
+
+
+ class DummyImageProcessor(BaseImageProcessor):
+     async def process_images_async(self, *args, **kwargs):
+         return None
+
+
+ class LlavaImageProcessor(BaseImageProcessor):
+     def __init__(self, hf_config, server_args, _image_processor):
+         self.hf_config = hf_config
+         self._image_processor = _image_processor
+         self.executor = concurrent.futures.ProcessPoolExecutor(
+             initializer=init_global_processor,
+             mp_context=mp.get_context("fork"),
+             initargs=(server_args,),
+             max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+         )
+
+     @staticmethod
+     def _process_single_image_task(
+         image_data: Union[str, bytes],
+         image_aspect_ratio: Optional[str] = None,
+         image_grid_pinpoints: Optional[str] = None,
+         image_processor=None,
+     ):
+         image_processor = image_processor or global_processor.image_processor
+
+         try:
+             image, image_size = load_image(image_data)
+             if image_size is not None:
+                 # It is a video with multiple images
+                 image_hash = hash(image_data)
+                 pixel_values = image_processor(image)["pixel_values"]
+                 for _ in range(len(pixel_values)):
+                     pixel_values[_] = pixel_values[_].astype(np.float16)
+                 pixel_values = np.stack(pixel_values, axis=0)
+                 return pixel_values, image_hash, image_size
+             else:
+                 # It is an image
+                 image_hash = hash(image_data)
+                 if image_aspect_ratio == "pad":
+                     image = expand2square(
+                         image,
+                         tuple(int(x * 255) for x in image_processor.image_mean),
+                     )
+                     pixel_values = image_processor(image.convert("RGB"))[
+                         "pixel_values"
+                     ][0]
+                 elif image_aspect_ratio == "anyres" or (
+                     image_aspect_ratio is not None
+                     and "anyres_max" in image_aspect_ratio
+                 ):
+                     pixel_values = process_anyres_image(
+                         image, image_processor, image_grid_pinpoints
+                     )
+                 else:
+                     pixel_values = image_processor(image)["pixel_values"][0]
+
+                 if isinstance(pixel_values, np.ndarray):
+                     pixel_values = pixel_values.astype(np.float16)
+
+                 return pixel_values, image_hash, image.size
+         except Exception:
+             logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+
+     async def _process_single_image(
+         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+     ):
+         if self.executor is not None:
+             loop = asyncio.get_event_loop()
+             return await loop.run_in_executor(
+                 self.executor,
+                 LlavaImageProcessor._process_single_image_task,
+                 image_data,
+                 aspect_ratio,
+                 grid_pinpoints,
+             )
+         else:
+             return self._process_single_image_task(
+                 image_data, aspect_ratio, grid_pinpoints
+             )
+
+     async def process_images_async(
+         self, image_data: List[Union[str, bytes]], request_obj
+     ):
+         if not image_data:
+             return None
+
+         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
+         grid_pinpoints = (
+             self.hf_config.image_grid_pinpoints
+             if hasattr(self.hf_config, "image_grid_pinpoints")
+             and "anyres" in aspect_ratio
+             else None
+         )
+
+         if isinstance(image_data, list) and len(image_data) > 0:
+             # Multiple images
+             if len(image_data) > 1:
+                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                 pixel_values, image_hashes, image_sizes = [], [], []
+                 res = []
+                 for img_data in image_data:
+                     res.append(
+                         self._process_single_image(
+                             img_data, aspect_ratio, grid_pinpoints
+                         )
+                     )
+                 res = await asyncio.gather(*res)
+                 for pixel_v, image_h, image_s in res:
+                     pixel_values.append(pixel_v)
+                     image_hashes.append(image_h)
+                     image_sizes.append(image_s)
+
+                 if isinstance(pixel_values[0], np.ndarray):
+                     pixel_values = np.stack(pixel_values, axis=0)
+             else:
+                 # A single image
+                 pixel_values, image_hash, image_size = await self._process_single_image(
+                     image_data[0], aspect_ratio, grid_pinpoints
+                 )
+                 image_hashes = [image_hash]
+                 image_sizes = [image_size]
+         elif isinstance(image_data, str):
+             # A single image
+             pixel_values, image_hash, image_size = await self._process_single_image(
+                 image_data, aspect_ratio, grid_pinpoints
+             )
+             image_hashes = [image_hash]
+             image_sizes = [image_size]
+         else:
+             raise ValueError(f"Invalid image data: {image_data}")
+
+         return {
+             "pixel_values": pixel_values,
+             "image_hashes": image_hashes,
+             "image_sizes": image_sizes,
+             "modalities": request_obj.modalities,
+         }
+
+
+ def get_image_processor(
+     hf_config, server_args: ServerArgs, _image_processor
+ ) -> BaseImageProcessor:
+     return LlavaImageProcessor(hf_config, server_args, _image_processor)
+
+
+ def get_dummy_image_processor():
+     return DummyImageProcessor()
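The new LlavaImageProcessor follows a preload-and-offload pattern: a ProcessPoolExecutor whose initializer/initargs build a global HuggingFace processor once per worker, with run_in_executor dispatching the blocking preprocessing from the async path. The stripped-down sketch below shows the same pattern with a stand-in workload; init_worker, heavy_preprocess, and GLOBAL_STATE are illustrative names, not part of sglang, and the fork start method assumes a Linux-like platform.

    import asyncio
    import concurrent.futures
    import multiprocessing as mp
    import os

    GLOBAL_STATE = None  # set once per worker process

    def init_worker(tag: str):
        # Runs in each worker on startup, like init_global_processor(server_args).
        global GLOBAL_STATE
        GLOBAL_STATE = f"processor-{tag}-{os.getpid()}"

    def heavy_preprocess(item: str):
        # Stand-in for image decoding / pixel-value extraction.
        return f"{GLOBAL_STATE}: {item.upper()}"

    async def main():
        executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_worker,
            initargs=("demo",),
            mp_context=mp.get_context("fork"),
            max_workers=2,
        )
        loop = asyncio.get_running_loop()
        # Fan out the blocking work without stalling the event loop.
        results = await asyncio.gather(
            *(loop.run_in_executor(executor, heavy_preprocess, x) for x in ["img0", "img1"])
        )
        print(results)
        executor.shutdown()

    if __name__ == "__main__":
        asyncio.run(main())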