sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processors/qwen_vl.py
@@ -0,0 +1,161 @@
+ import asyncio
+ import math
+ from typing import List, Union
+
+ from PIL import Image
+
+ from sglang.srt.managers.image_processor import BaseImageProcessor
+ from sglang.srt.managers.image_processors.base_image_processor import (
+     get_global_processor,
+ )
+ from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+
+
+ # Compatible with Qwen2VL and Qwen2_5VL
+ class Qwen2_5VLImageProcessor(BaseImageProcessor):
+     def __init__(self, hf_config, server_args, _processor):
+         super().__init__(hf_config, server_args, _processor)
+         self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
+         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+         self.image_token_id = hf_config.image_token_id
+         self.video_token_id = hf_config.video_token_id
+         self.NUM_TOKEN_PER_FRAME = 770
+         self.IMAGE_FACTOR = 28
+         self.MIN_PIXELS = 4 * 28 * 28
+         self.MAX_PIXELS = 16384 * 28 * 28
+         self.MAX_RATIO = 200
+
+     @staticmethod
+     def _process_images_task(images, input_text, _hf_config):
+         if isinstance(images, list) and len(images) == 0:
+             images = None
+         result = get_global_processor().__call__(
+             text=[input_text], images=images, padding=True, return_tensors="pt"
+         )
+
+         return {
+             "input_ids": result.input_ids,
+             "pixel_values": getattr(result, "pixel_values", None),
+             "image_grid_thw": getattr(result, "image_grid_thw", None),
+             "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
+             "video_grid_thws": getattr(result, "video_grid_thws", None),
+         }
+
+     async def _process_images(self, images, input_text) -> dict:
+         if self.executor is not None:
+             loop = asyncio.get_event_loop()
+             return await loop.run_in_executor(
+                 self.executor,
+                 Qwen2_5VLImageProcessor._process_images_task,
+                 images,
+                 input_text,
+                 self.hf_config,
+             )
+         else:
+             return self._process_images_task(images, input_text, self.hf_config)
+
+     async def process_images_async(
+         self,
+         image_data: List[Union[str, bytes]],
+         input_ids,
+         request_obj,
+         max_req_input_len,
+         *args,
+         **kwargs,
+     ):
+         if not image_data:
+             return None
+         if isinstance(image_data, str):
+             image_data = [image_data]
+
+         image_token = self.IMAGE_TOKEN
+         base_output = self.load_images(
+             input_ids,
+             image_data,
+             image_token,
+             max_req_input_len,
+         )
+
+         def smart_resize(
+             height: int,
+             width: int,
+             factor: int = self.IMAGE_FACTOR,
+             min_pixels: int = self.MIN_PIXELS,
+             max_pixels: int = self.MAX_PIXELS,
+         ) -> tuple[int, int]:
+             """
+             Rescales the image so that the following conditions are met:
+
+             1. Both dimensions (height and width) are divisible by 'factor'.
+
+             2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+             3. The aspect ratio of the image is maintained as closely as possible.
+             """
+             if max(height, width) / min(height, width) > self.MAX_RATIO:
+                 raise ValueError(
+                     f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
+                 )
+             h_bar = max(factor, round_by_factor(height, factor))
+             w_bar = max(factor, round_by_factor(width, factor))
+             if h_bar * w_bar > max_pixels:
+                 beta = math.sqrt((height * width) / max_pixels)
+                 h_bar = floor_by_factor(height / beta, factor)
+                 w_bar = floor_by_factor(width / beta, factor)
+             elif h_bar * w_bar < min_pixels:
+                 beta = math.sqrt(min_pixels / (height * width))
+                 h_bar = ceil_by_factor(height * beta, factor)
+                 w_bar = ceil_by_factor(width * beta, factor)
+             return h_bar, w_bar
+
+         def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
+             width, height = image.size
+             min_pixels = self.MIN_PIXELS
+             max_pixels = self.MAX_PIXELS
+             resized_height, resized_width = smart_resize(
+                 height,
+                 width,
+                 factor=size_factor,
+                 min_pixels=min_pixels,
+                 max_pixels=max_pixels,
+             )
+             image = image.resize((resized_width, resized_height))
+             return image
+
+         def round_by_factor(number: int, factor: int) -> int:
+             """Returns the closest integer to 'number' that is divisible by 'factor'."""
+             return round(number / factor) * factor
+
+         def ceil_by_factor(number: int, factor: int) -> int:
+             """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+             return math.ceil(number / factor) * factor
+
+         def floor_by_factor(number: int, factor: int) -> int:
+             """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+             return math.floor(number / factor) * factor
+
+         images = [resize_image(image) for image in base_output.all_frames]
+
+         ret = await self._process_images(images, base_output.input_text)
+         return {
+             "input_ids": ret["input_ids"].flatten().tolist(),
+             "pixel_values": ret["pixel_values"],
+             "image_hashes": base_output.image_hashes,
+             "modalities": request_obj.modalities or ["image"],
+             "image_grid_thws": ret["image_grid_thw"],
+             "video_grid_thws": ret["video_grid_thws"],
+             "im_start_id": self.IM_START_TOKEN_ID,
+             "im_end_id": self.IM_END_TOKEN_ID,
+             "im_token_id": self.image_token_id,
+             "video_token_id": self.video_token_id,
+             "second_per_grid_ts": ret["second_per_grid_ts"],
+         }
+
+
+ ImageProcessorMapping = {
+     Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
+     Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
+ }
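
The smart_resize logic above is easiest to verify with concrete numbers. Below is a minimal standalone sketch (constants copied from the class; no sglang or PIL dependencies, and the helper names simply mirror the nested functions above), showing how the 28-pixel snapping and the pixel budget interact:

    import math

    IMAGE_FACTOR = 28
    MIN_PIXELS = 4 * 28 * 28        # 3,136
    MAX_PIXELS = 16384 * 28 * 28    # 12,845,056
    MAX_RATIO = 200

    def round_by_factor(n, factor):
        return round(n / factor) * factor

    def floor_by_factor(n, factor):
        return math.floor(n / factor) * factor

    def ceil_by_factor(n, factor):
        return math.ceil(n / factor) * factor

    def smart_resize(height, width, factor=IMAGE_FACTOR,
                     min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS):
        if max(height, width) / min(height, width) > MAX_RATIO:
            raise ValueError("aspect ratio above MAX_RATIO")
        # Snap both sides to the nearest multiple of `factor`.
        h_bar = max(factor, round_by_factor(height, factor))
        w_bar = max(factor, round_by_factor(width, factor))
        # Rescale uniformly (via beta) when the snapped size violates the budget.
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = floor_by_factor(height / beta, factor)
            w_bar = floor_by_factor(width / beta, factor)
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = ceil_by_factor(height * beta, factor)
            w_bar = ceil_by_factor(width * beta, factor)
        return h_bar, w_bar

    print(smart_resize(700, 1000))   # (700, 1008): sides snapped to multiples of 28
    print(smart_resize(8000, 8000))  # (3584, 3584): scaled down to the 12.8M-pixel cap

The uniform beta scaling is what preserves the aspect ratio: both sides shrink (or grow) by the same factor before being re-snapped, so the product lands inside [min_pixels, max_pixels] while the ratio moves as little as possible.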
sglang/srt/managers/io_struct.py
@@ -293,6 +293,8 @@ class TokenizedGenerateReqInput:
  class EmbeddingReqInput:
      # The input prompt. It can be a single prompt or a batch of prompts.
      text: Optional[Union[List[str], str]] = None
+     # The image input. It can be a file name, a url, or base64 encoded string.
+     image_data: Optional[Union[List[str], str]] = None
      # The token ids for text; one can either specify text or input_ids.
      input_ids: Optional[Union[List[List[int]], List[int]]] = None
      # The request id.
@@ -303,28 +305,40 @@ class EmbeddingReqInput:
      input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
      # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
      log_metrics: bool = True
+     # The modalities of the image data [image, multi-images, video]
+     modalities: Optional[List[str]] = None

      def normalize_batch_and_arguments(self):
-         if (self.text is None and self.input_ids is None) or (
-             self.text is not None and self.input_ids is not None
-         ):
-             raise ValueError("Either text or input_ids should be provided.")
+         # at least one of text, input_ids, or image should be provided
+         if self.text is None and self.input_ids is None and self.image_data is None:
+             raise ValueError(
+                 "At least one of text, input_ids, or image should be provided"
+             )
+
+         # text and input_ids cannot be provided at the same time
+         if self.text is not None and self.input_ids is not None:
+             raise ValueError("text and input_ids cannot be provided at the same time")

          # Derive the batch size
+         self.batch_size = 0
+         self.is_single = True
+
+         # check the batch size of text
          if self.text is not None:
-             if isinstance(self.text, str):
-                 self.is_single = True
-                 self.batch_size = 1
+             if isinstance(self.text, list):
+                 self.batch_size += len(self.text)
              else:
-                 self.is_single = False
-                 self.batch_size = len(self.text)
-         else:
-             if isinstance(self.input_ids[0], int):
-                 self.is_single = True
-                 self.batch_size = 1
+                 self.batch_size += 1
+
+         # check the batch size of input_ids
+         if self.input_ids is not None:
+             if isinstance(self.input_ids[0], list):
+                 self.batch_size += len(self.input_ids)
              else:
-                 self.is_single = False
-                 self.batch_size = len(self.input_ids)
+                 self.batch_size += 1
+
+         if self.batch_size > 1:
+             self.is_single = False

          # Fill in default arguments
          if self.is_single:
@@ -352,6 +366,7 @@ class EmbeddingReqInput:
          return EmbeddingReqInput(
              text=self.text[i] if self.text is not None else None,
              input_ids=self.input_ids[i] if self.input_ids is not None else None,
+             image_data=self.image_data[i] if self.image_data is not None else None,
              sampling_params=self.sampling_params[i],
              rid=self.rid[i],
          )
@@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput:
      input_text: str
      # The input token ids
      input_ids: List[int]
+     # The image inputs
+     image_inputs: dict
      # Dummy sampling params for compatibility
      sampling_params: SamplingParams

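Note that normalize_batch_and_arguments no longer treats text and input_ids as an either/or pair: it sums their batch contributions and now admits image-only requests. A plain-function sketch of the same derivation (derive_batch is a hypothetical name, not the sglang API) makes the accepted and rejected combinations concrete:

    def derive_batch(text=None, input_ids=None, image_data=None):
        # Mirrors the validation above: at least one input kind is required,
        # and text / input_ids are mutually exclusive.
        if text is None and input_ids is None and image_data is None:
            raise ValueError("At least one of text, input_ids, or image should be provided")
        if text is not None and input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        batch_size = 0
        if text is not None:
            batch_size += len(text) if isinstance(text, list) else 1
        if input_ids is not None:
            batch_size += len(input_ids) if isinstance(input_ids[0], list) else 1
        return batch_size, batch_size <= 1   # (batch_size, is_single)

    print(derive_batch(text="hello"))                # (1, True)
    print(derive_batch(text=["a", "b", "c"]))        # (3, False)
    print(derive_batch(input_ids=[[1, 2], [3, 4]]))  # (2, False)
    print(derive_batch(image_data="cat.png"))        # (0, True): image-only is now legal
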
sglang/srt/managers/multi_modality_padding.py
@@ -0,0 +1,134 @@
+ from abc import abstractmethod
+ from typing import Callable, List, Optional, Tuple
+
+ from sglang.srt.managers.schedule_batch import ImageInputs
+ from sglang.utils import logger
+
+
+ class MultiModalityDataPaddingPattern:
+     """
+     Data tokens (like image tokens) often need special handling during padding
+     to maintain model compatibility. This class provides the interface for
+     implementing different padding strategies for data tokens
+     """
+
+     @abstractmethod
+     def pad_input_tokens(
+         self, input_ids: List[int], image_inputs: ImageInputs
+     ) -> List[int]:
+         """
+         Pad the input ids sequence containing data tokens, and replace them with pad_values
+         """
+         pass
+
+
+ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
+     """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
+
+     This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
+     """
+
+     def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+         self.data_token_id_pairs = data_token_pairs
+
+     def pad_input_tokens(
+         self, input_ids: List[int], image_inputs: ImageInputs
+     ) -> List[int]:
+         """
+         This function will replace the data tokens in between with pad_values accordingly
+         """
+         pad_values = image_inputs.pad_values
+         data_token_pairs = self.data_token_id_pairs
+         image_inputs.image_offsets = []
+         if data_token_pairs is None:
+             data_token_pairs = [(image_inputs.im_start_id, image_inputs.im_end_id)]
+         if data_token_pairs is None:
+             logger.warning(
+                 "No data_token_pairs provided, RadixAttention might be influenced."
+             )
+             return input_ids
+         start_token_ids = [s for s, _e in data_token_pairs]
+         end_tokens_ids = [e for _s, e in data_token_pairs]
+         # First start token marks new data
+         data_start_token = start_token_ids[0]
+
+         padded_ids = []
+         last_idx = 0
+         data_idx = -1
+
+         start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
+         end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
+
+         if len(start_indices) != len(end_indices):
+             return input_ids
+
+         for start_idx, end_idx in zip(start_indices, end_indices):
+             padded_ids.extend(input_ids[last_idx : start_idx + 1])
+
+             if input_ids[start_idx] == data_start_token:
+                 data_idx += 1
+                 image_inputs.image_offsets += [start_idx]
+
+             num_tokens = end_idx - start_idx - 1
+             pad_value = pad_values[data_idx]
+             padded_ids.extend([pad_value] * num_tokens)
+
+             last_idx = end_idx
+
+         padded_ids.extend(input_ids[last_idx:])
+
+         assert len(input_ids) == len(padded_ids)
+         return padded_ids
+
+
+ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
+     """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
+     which needs first to be expanded to multiple tokens, then replaced with their padding values
+
+     This strategy should be used when a single data token represents content that should
+     be expanded to multiple tokens during processing.
+     """
+
+     def __init__(
+         self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
+     ) -> None:
+         self.num_data_token_calc_func = num_data_token_calc_func
+
+     def pad_input_tokens(
+         self, input_ids: List[int], image_inputs: ImageInputs
+     ) -> List[int]:
+         """
+         This function will follow the procedure of:
+         1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
+         2. the padded data tokens will be replaced with their pad_values
+         """
+         image_grid_thws = image_inputs.image_grid_thws
+         pad_values = image_inputs.pad_values
+
+         image_indices = [
+             idx
+             for idx, token in enumerate(input_ids)
+             if token == image_inputs.im_token_id
+         ]
+
+         image_inputs.image_offsets = []
+
+         input_ids_with_image = []
+         for image_cnt, _ in enumerate(image_grid_thws):
+             num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
+             if image_cnt == 0:
+                 non_image_tokens = input_ids[: image_indices[image_cnt]]
+             else:
+                 non_image_tokens = input_ids[
+                     image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                 ]
+             input_ids_with_image.extend(non_image_tokens)
+             image_inputs.image_offsets.append(len(input_ids_with_image))
+             pad_ids = pad_values * (
+                 (num_image_tokens + len(pad_values)) // len(pad_values)
+             )
+             input_ids_with_image.extend(pad_ids[:num_image_tokens])
+         input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+         return input_ids_with_image
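
Both padding strategies in this new file reduce to straightforward list surgery, which can be sanity-checked without sglang at all. The sketch below is a dependency-free rework of the token-pair pattern (pad_token_pairs, the token ids 10/11, and the pad value 99 are all made up for illustration; the real class also records image_offsets on the ImageInputs object):

    def pad_token_pairs(input_ids, start_id, end_id, pad_values):
        starts = [i for i, x in enumerate(input_ids) if x == start_id]
        ends = [i for i, x in enumerate(input_ids) if x == end_id]
        if len(starts) != len(ends):
            return input_ids  # malformed pairing: bail out unchanged, as above

        padded, last_idx = [], 0
        for data_idx, (s, e) in enumerate(zip(starts, ends)):
            padded.extend(input_ids[last_idx : s + 1])           # text + start token
            padded.extend([pad_values[data_idx]] * (e - s - 1))  # overwrite image span
            last_idx = e                                         # end token kept by the next extend
        padded.extend(input_ids[last_idx:])
        assert len(padded) == len(input_ids)  # length preserved, only values change
        return padded

    # 1, 2, 3 = text; 10/11 = vision start/end; 7 = image placeholder tokens
    ids = [1, 2, 10, 7, 7, 7, 11, 3]
    print(pad_token_pairs(ids, 10, 11, pad_values=[99]))
    # [1, 2, 10, 99, 99, 99, 11, 3]

The single-token pattern differs in that each lone placeholder is first expanded to num_data_token_calc_func(grid_thw) tokens before being overwritten, so the sequence grows, whereas the token-pair pattern keeps its length fixed.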