sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import concurrent.futures
 import logging
 import math
 import threading
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
        )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
             )
         return device_indices
 
+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
@@ -289,10 +298,14 @@
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.mem_pool_host.write_page_all_layers(
-                    operation.host_indices,
-                    operation.device_indices,
-                    self.mem_pool_device,
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@
             except Exception as e:
                 logger.error(e)
 
-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@
 
             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-                if self.page_size == 1:
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
-                else:
-                    self.mem_pool_host.load_page_per_layer(
-                        batch_operation.host_indices,
-                        batch_operation.device_indices,
-                        self.mem_pool_device,
-                        i,
-                    )
-                    self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()
 
             self.mem_pool_host.complete_io(batch_operation.host_indices)
@@ -372,148 +360,6 @@
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)
 
-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
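
The HiCacheController hunks above replace the old page_size == 1 double-buffering path (write_thread_func_buffer, load_thread_func_buffer, and their aux functions) with a single direct path whose transfer strategy is fixed at construction time: pages smaller than a 64-entry threshold use the "kernel" backend, larger pages use "direct" indexing, and move_indices places the index tensors on whichever side the chosen backend reads from. A minimal standalone sketch of that selection rule, assuming only the 64-entry threshold shown in the hunks above (the helper names here are illustrative, not sglang APIs):

    import torch

    IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # same constant as in the __init__ hunk above

    def choose_io_backend(page_size: int, io_backend: str = "") -> str:
        # kernel-based gather/scatter for small pages, direct (DMA-style) indexing for large ones
        if io_backend:
            return io_backend
        return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

    def place_indices(io_backend: str, host_indices: torch.Tensor,
                      device_indices: torch.Tensor, device: str = "cuda"):
        # mirrors HiCacheController.move_indices: the kernel backend wants both index
        # tensors on the GPU, while the direct backend indexes from the host side
        if io_backend == "kernel":
            return host_indices.to(device), device_indices
        if io_backend == "direct":
            return host_indices, device_indices.cpu()
        raise ValueError(f"Unsupported io backend: {io_backend}")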
sglang/srt/managers/io_struct.py

@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None
 
     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )
 
     def normalize_batch_and_arguments(self):
         """
@@ -200,6 +206,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -230,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -298,6 +307,15 @@ class GenerateReqInput:
             self.image_data = wrapped_images * self.parallel_sample_num
             self.modalities = ["image"] * num
 
+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
@@ -324,7 +342,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-            if len(self.rid) != num:
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
@@ -400,7 +420,11 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
+            video_data=self.video_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -500,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -571,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid
 
     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )
 
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
@@ -898,6 +928,7 @@ class ProfileReqInput:
     # If set, it profile as many as this number of steps.
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
@@ -925,6 +956,7 @@ class ExpertDistributionReqOutput:
 class ProfileReq:
     type: ProfileReqType
     output_dir: Optional[str] = None
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
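
The io_struct.py hunks above add a video_data field to GenerateReqInput and EmbeddingReqInput, normalize it alongside image_data and audio_data, propagate input_embeds and video_data through batch slicing, and add a start_step field to the profiling requests. A hedged usage sketch, assuming the /generate HTTP endpoint accepts video_data with the same file-path/URL/base64 conventions as image_data (the file path and port are placeholders, and whether a given model actually consumes video depends on its processor):

    import requests

    payload = {
        "text": "Describe what happens in this clip.",
        # like image_data: a path, URL, or base64 string; a list of lists
        # would carry multiple clips per request in a batch
        "video_data": "/path/to/clip.mp4",
        "sampling_params": {"max_new_tokens": 128},
    }
    response = requests.post("http://localhost:30000/generate", json=payload)
    print(response.json())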
sglang/srt/managers/mm_utils.py

@@ -4,7 +4,7 @@ Multi-modality utils
 
 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         return ret_input_ids
 
 
-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None
 
 
-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)
 
@@ -248,11 +249,14 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -269,7 +273,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
@@ -378,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
@@ -395,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
 
     Returns:
@@ -413,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
 
     embeddings, masks = [], []
-
     # 2. Get multimodal embedding separately
-    # TODO: make this more generic
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
         )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
             )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
                 )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+            embeddings += [embedding]
+            masks += [mask]
 
     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
@@ -521,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -570,8 +533,8 @@ def general_mm_embed_routine(
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            image_data_embedding_func=image_data_embedding_func,
-            audio_data_embedding_func=audio_data_embedding_func,
+            multimodal_model=multimodal_model,
+            data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
         )
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
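
The mm_utils.py hunks above fold the separate image_data_embedding_func / audio_data_embedding_func arguments into a single modality loop: for each Modality, embed_mm_inputs first consults the optional data_embedding_func_mapping and otherwise resolves an embedder via getattr(multimodal_model, f"get_{modality}_feature", None). A minimal sketch of what a model has to expose under this contract (the class, layer sizes, and tensor handling are illustrative, not copied from any sglang model):

    from typing import List

    import torch
    import torch.nn as nn

    class ToyVLM(nn.Module):
        """Toy model exposing the per-modality hooks that the new routing discovers by name."""

        def __init__(self, hidden_size: int = 16):
            super().__init__()
            self.proj = nn.Linear(8, hidden_size)

        def get_image_feature(self, items: List) -> torch.Tensor:
            # items would be MultimodalDataItem objects; here we just project a
            # stand-in feature tensor to the language model's hidden size
            feats = torch.stack(
                [torch.as_tensor(it.feature, dtype=torch.float32) for it in items]
            )
            return self.proj(feats)

        def get_video_feature(self, items: List) -> torch.Tensor:
            # same contract as images: one embedding row per padded multimodal token
            return self.get_image_feature(items)

    # a model's forward would then hand itself to the shared routine, e.g.
    # general_mm_embed_routine(input_ids, forward_batch, self.language_model,
    #                          multimodal_model=self, ...)
    # so the loop above can route each modality to the hooks defined here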