sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-import concurrent.futures
 import logging
 import math
 import threading
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
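Note: the new io_backend argument lets HiCacheController pick between a copy-kernel path for small pages and direct DMA-style indexing for large pages. A minimal standalone sketch of the selection rule shown above (the helper name below is hypothetical; the 64-token threshold comes from the hunk):

```python
# Sketch of the backend selection rule added above (hypothetical helper).
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # pages at or above this size use direct indexing

def pick_io_backend(page_size: int, io_backend: str = "") -> str:
    # An explicitly configured backend wins; otherwise choose by page size.
    if io_backend:
        return io_backend
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

assert pick_io_backend(1) == "kernel"
assert pick_io_backend(64) == "direct"
```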
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices

+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
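Note: move_indices mirrors the backend choice above: the kernel path wants both index tensors resident on the GPU, while the direct path indexes host memory and therefore needs the device indices on the CPU. A self-contained sketch of the same dispatch (tensor names and the device string are illustrative, not the controller's real state):

```python
import torch

def move_indices(host_indices: torch.Tensor, device_indices: torch.Tensor,
                 io_backend: str, device: str = "cuda"):
    # Kernel path copies with on-device indices; direct path pulls device indices to host.
    if io_backend == "kernel":
        return host_indices.to(device), device_indices
    elif io_backend == "direct":
        return host_indices, device_indices.cpu()
    raise ValueError(f"Unsupported io backend: {io_backend}")
```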
@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.mem_pool_host.write_page_all_layers(
-                    operation.host_indices,
-                    operation.device_indices,
-                    self.mem_pool_device,
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:

             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-                if self.page_size == 1:
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
-                else:
-                    self.mem_pool_host.load_page_per_layer(
-                        batch_operation.host_indices,
-                        batch_operation.device_indices,
-                        self.mem_pool_device,
-                        i,
-                    )
-                self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()

             self.mem_pool_host.complete_io(batch_operation.host_indices)
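Note: the load path now asks the device pool to pull one KV layer at a time via load_from_host_per_layer, synchronizes the dedicated load stream, and only then bumps layer_done_counter, so attention on layer i can run while layer i+1 is still copying. A schematic of that loop under those assumptions (the pool objects and counter below are stand-ins, not the real classes):

```python
# Schematic of the per-layer host-to-device copy loop (stand-in pool objects).
import torch

def load_layer_by_layer(device_pool, host_pool, host_idx, device_idx,
                        io_backend, layer_done_counter, load_stream):
    with torch.cuda.stream(load_stream):
        for layer_id in range(host_pool.layer_num):
            device_pool.load_from_host_per_layer(
                host_pool, host_idx, device_idx, layer_id, io_backend
            )
            load_stream.synchronize()       # the copy for this layer is complete
            layer_done_counter.increment()  # unblock consumers of this layer
```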
@@ -372,148 +360,6 @@ class HiCacheController:
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)

-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
sglang/srt/managers/io_struct.py
CHANGED
@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def normalize_batch_and_arguments(self):
         """
@@ -200,6 +206,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]

     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -230,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -298,6 +307,15 @@ class GenerateReqInput:
             self.image_data = wrapped_images * self.parallel_sample_num
             self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
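Note: _normalize_video_data follows the same broadcast rules as the existing image/audio helpers: None expands to a per-request list of None, a single value is repeated for every request in the batch, and a list is tiled by parallel_sample_num. A small illustration of those rules outside the dataclass (hypothetical inputs):

```python
# Illustration of the broadcast rules used by _normalize_video_data (hypothetical inputs).
def normalize_video_data(video_data, num, parallel_sample_num=1):
    if video_data is None:
        return [None] * num
    if not isinstance(video_data, list):
        return [video_data] * num
    return video_data * parallel_sample_num

assert normalize_video_data(None, 2) == [None, None]
assert normalize_video_data("clip.mp4", 2) == ["clip.mp4", "clip.mp4"]
assert normalize_video_data(["a.mp4", "b.mp4"], 2, 1) == ["a.mp4", "b.mp4"]
```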
@@ -324,7 +342,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-            if len(self.rid) != num:
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
@@ -400,7 +420,11 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
+            video_data=self.video_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -500,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -571,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def __getitem__(self, i):
         if self.is_cross_encoder_request:
@@ -898,6 +928,7 @@ class ProfileReqInput:
     # If set, it profile as many as this number of steps.
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
@@ -925,6 +956,7 @@ class ExpertDistributionReqOutput:
 class ProfileReq:
     type: ProfileReqType
     output_dir: Optional[str] = None
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
sglang/srt/managers/mm_utils.py
CHANGED
@@ -4,7 +4,7 @@ Multi-modality utils

 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern)
         return ret_input_ids


-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None


-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)

@@ -248,11 +249,14 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -269,7 +273,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
@@ -378,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
@@ -395,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

     Returns:
@@ -413,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

     embeddings, masks = [], []
-
     # 2. Get multimodal embedding separately
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
         )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
             )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
                 )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
             )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+            embeddings += [embedding]
+            masks += [mask]

     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
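Note: embed_mm_inputs no longer takes separate image/audio embedding callables; it loops over every Modality, looks the embedder up in data_embedding_func_mapping, and otherwise falls back to a get_<modality>_feature method on the multimodal model. A reduced sketch of that resolution order (names mirror the hunk above; the model object is a stand-in and Modality is assumed to be an enum):

```python
# Reduced sketch of the embedder lookup order introduced above (stand-in model object).
def resolve_embedder(modality, data_embedding_func_mapping=None, multimodal_model=None):
    embedder = (
        None
        if data_embedding_func_mapping is None
        else data_embedding_func_mapping.get(modality, None)
    )
    if embedder is None and multimodal_model is not None:
        # e.g. Modality.IMAGE -> "get_image_feature"
        embedder = getattr(multimodal_model, f"get_{modality.name.lower()}_feature", None)
    return embedder
```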
@@ -521,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -570,8 +533,8 @@ def general_mm_embed_routine(
         extend_seq_lens=extend_seq_lens,
         input_ids=input_ids,
         input_embedding=embed_tokens,
-        image_data_embedding_func=image_data_embedding_func,
-        audio_data_embedding_func=audio_data_embedding_func,
+        multimodal_model=multimodal_model,
+        data_embedding_func_mapping=data_embedding_funcs,
         placeholder_tokens=placeholder_tokens,
     )
     # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models