sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/base_processor.py

@@ -4,14 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast
 
-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image
 
 
 @dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios
 
         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
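
The hunk above routes preprocessing to the GPU whenever the HuggingFace processor wraps a fast (torch-backed) image processor. A minimal standalone sketch of the same check, outside of sglang (the model id in the comment is a placeholder):

from transformers import AutoProcessor, BaseImageProcessorFast

def build_processor_kwargs(processor, prefer_cuda: bool = True) -> dict:
    """Return extra kwargs for processor.__call__, mirroring the check above."""
    kwargs = {}
    if (
        prefer_cuda
        and hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
    ):
        # Fast image processors accept a `device` argument and run resize/normalize
        # on GPU tensors instead of PIL/numpy on the CPU.
        kwargs["device"] = "cuda"
    return kwargs

# Usage (assumes a checkpoint that ships a fast image processor):
# processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)
# out = processor(text=["<image> describe"], images=[img], return_tensors="pt",
#                 **build_processor_kwargs(processor))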
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
 
         return estimated_frames_list
 
+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        load multimodal data parallelly
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
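
The new submit_data_loading_tasks follows a submit-then-collect pattern: every image/audio placeholder is dispatched to io_executor first, and a parallel task_info list remembers what each future corresponds to. A self-contained sketch of that pattern with concurrent.futures (illustrative only, not the sglang API):

from concurrent.futures import ProcessPoolExecutor

def load_one(item: str) -> str:
    # Placeholder worker standing in for load_image / encode_video / load_audio;
    # it must live at module level so the process pool can pickle it.
    return item.upper()

def load_all(items: list[str]) -> list[tuple[str, str]]:
    futures, task_info = [], []
    with ProcessPoolExecutor(max_workers=4) as pool:
        # 1) Submit everything first so the loads overlap.
        for item in items:
            futures.append(pool.submit(load_one, item))
            task_info.append(("text", item))
        # 2) Collect in submission order; task_info stays aligned with futures.
        return [(info[1], fut.result()) for info, fut in zip(task_info, futures)]

if __name__ == "__main__":
    print(load_all(["a", "b", "c"]))  # [('a', 'A'), ('b', 'B'), ('c', 'C')]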
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)
 
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-                        assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+                # TODO: handle video
+            else:
+                new_text += text_part
 
         out = BaseMultiModalProcessorOutput(
             images=images,
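
On the consumption side, load_mm_data now walks the split prompt again, pops one future per special token, and re-expands the image token by the number of frames actually loaded. A toy sketch of that prompt-rebuilding step (the "<image>" token and frame counts are made-up values):

import re

def rebuild_prompt(prompt: str, image_token: str, frames_per_image: list[int]) -> str:
    # Keep the separators by wrapping the escaped token in a capturing group.
    pattern = "(" + re.escape(image_token) + ")"
    parts = re.split(pattern, prompt)
    out, idx = "", 0
    for part in parts:
        if part == image_token:
            out += image_token * frames_per_image[idx]  # one token per loaded frame
            idx += 1
        else:
            out += part
    return out

print(rebuild_prompt("describe <image> and <image>", "<image>", [3, 1]))
# -> "describe <image><image><image> and <image>"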
sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )
 
sglang/srt/managers/multimodal_processors/mllama4.py

@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union
 
 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image
 
 
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )
 
         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor
 
         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )
 
         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
 
@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]
 
             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
 
             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)
 
             if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)
 
                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)
 
-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch
 
         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index
 
         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]
 
-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
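
The mllama4 hunks keep the existing embed_is_patch logic: input_ids are split at each end-of-image token and compared against the patch token id to mark where vision embeddings go. A toy reconstruction of that computation with hypothetical token ids (not the real Llama 4 vocabulary):

import torch

PATCH_ID, IMAGE_END_ID = 7, 9  # hypothetical vocab ids

def per_image_patch_masks(input_ids: torch.Tensor) -> list[torch.Tensor]:
    # Positions of the end-of-image token delimit the per-image segments.
    end_positions = (input_ids == IMAGE_END_ID).nonzero(as_tuple=True)[0]
    masks, start = [], 0
    for end in end_positions.tolist():
        segment = input_ids[start : end + 1]
        masks.append(segment == PATCH_ID)  # True where a vision patch embedding goes
        start = end + 1
    return masks

ids = torch.tensor([1, 7, 7, 9, 2, 7, 9, 3])
print([m.tolist() for m in per_image_patch_masks(ids)])
# -> [[False, True, True, False], [False, True, False]]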
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 from enum import Enum, auto
 
 # Copyright 2023-2024 SGLang Team
@@ -44,7 +45,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.disaggregation.conn import KVSender
+from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -82,6 +83,7 @@ global_server_args_dict = {
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
 logger = logging.getLogger(__name__)
@@ -157,7 +159,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """
 
     modality: Modality
@@ -195,17 +197,54 @@
 
     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
         """
 
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            assert isinstance(tensor, torch.Tensor)
+            if tensor.is_cuda:
+                # TODO: improve this
+                tensor_cpu = tensor.cpu()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())
+
         def hash_feature(f):
             if isinstance(f, list):
-                return hash(tuple(flatten_nested_list(f)))
+                if isinstance(f[0], torch.Tensor):
+                    return tensor_hash(f)
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)
 
         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
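
The new data_hash/tensor_hash helpers replace Python's built-in hash(), which is salted per process, with a truncated SHA-256 over the raw bytes, so the same multimodal payload maps to the same 64-bit id in every process. A standalone sketch of the scheme:

import hashlib
import torch

def data_hash(data: bytes) -> int:
    # Take the first 8 bytes of the SHA-256 digest as an unsigned 64-bit id.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False)

def tensor_hash(t: torch.Tensor) -> int:
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # memoryview/numpy cannot read bfloat16 buffers
    if t.is_cuda:
        t = t.cpu()
    return data_hash(memoryview(t.numpy()).tobytes())

a = torch.arange(8, dtype=torch.float32)
assert tensor_hash(a) == tensor_hash(a.clone())  # stable across copies and processes
print(hex(tensor_hash(a)))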
@@ -230,6 +269,9 @@
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
 
+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO
@@ -248,7 +290,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -268,11 +310,7 @@
         )
 
         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
 
         assert len(ret.mm_items) != 0
 
@@ -284,7 +322,6 @@
             item.set_pad_value()
 
         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -307,8 +344,8 @@
         """ """
         return any(item.is_audio() for item in self.mm_items)
 
-    def collect_image_inputs(self) -> List[torch.Tensor]:
-        return [item.pixel_values for item in self.mm_items if item.is_image()]
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())
 
     def merge(self, other: MultimodalInputs):
         """
@@ -322,10 +359,8 @@
 
         # args needed to be merged
         optional_args = [
-            "items",
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -354,6 +389,8 @@ class Req:
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
+        bootstrap_host: Optional[str] = None,
+        bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -438,6 +475,10 @@
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False
 
+        # Latency Breakdown
+        self.queue_time_start = None
+        self.queue_time_end = None
+
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
@@ -483,9 +524,9 @@
         self.lora_path = lora_path
 
         # For disaggregation
-        self.bootstrap_host: str = "0.0.0.0"
-        self.bootstrap_room: Optional[int] = None
-        self.disagg_kv_sender: Optional[KVSender] = None
+        self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_room: Optional[int] = bootstrap_room
+        self.disagg_kv_sender: Optional[BaseKVSender] = None
 
         # used for warmup because we don't have a pair yet when init
         self.skip_kv_transfer: bool = False
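
The last hunks let Req receive bootstrap_host and bootstrap_room from the caller instead of a hardcoded "0.0.0.0". Conceptually, the prefill and decode sides of a disaggregated deployment must agree on this pair so the KV sender and receiver can rendezvous. The sketch below is only a conceptual illustration; the dataclass name and the random room id are assumptions, not sglang's actual bootstrap protocol:

import dataclasses
import random

@dataclasses.dataclass
class DisaggBootstrapInfo:
    bootstrap_host: str
    bootstrap_room: int

def new_bootstrap_info(host: str) -> DisaggBootstrapInfo:
    # A random 63-bit room id is one simple way to give each request a unique
    # rendezvous key shared by both sides.
    return DisaggBootstrapInfo(bootstrap_host=host, bootstrap_room=random.getrandbits(63))

info = new_bootstrap_info("10.0.0.5")
# Both the prefill request and the matching decode request would then be built with
# Req(..., bootstrap_host=info.bootstrap_host, bootstrap_room=info.bootstrap_room).
print(info)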