sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -20,19 +20,18 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
-from sglang.srt.mm_utils import has_valid_data
+from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams
 
-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any
 
-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-
 
 @dataclass
 class SessionParams:
@@ -40,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None
 
 
 AudioDataItem = Union[str, Dict]
@@ -182,6 +182,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
@@ -199,6 +200,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -323,7 +326,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-            if len(self.rid) != num:
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
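
The rid check above now compares against batch_size rather than the expanded sample count: with parallel sampling, callers pass one rid per prompt, and per-sample ids are derived later in the tokenizer manager. A minimal standalone sketch of that rule (hypothetical rids; not an sglang API call):

# Hypothetical illustration of the rid rule in the hunk above: with 2 prompts
# and n=3, callers supply len(rid) == batch_size == 2; per-sample ids are then
# derived the same way as `new_rids = [f"{self.rid}_{i}" for i in range(num)]`.
batch_size = 2
parallel_sample_num = 3  # sampling_params["n"]
rids = ["req-a", "req-b"]
assert len(rids) == batch_size  # the new check: batch_size, not batch_size * n

expanded = [f"{rid}_{i}" for rid in rids for i in range(parallel_sample_num)]
print(expanded)  # ['req-a_0', 'req-a_1', 'req-a_2', 'req-b_0', 'req-b_1', 'req-b_2']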
@@ -399,6 +404,9 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
@@ -516,9 +524,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +577,9 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
         return self.rid
 
+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
@@ -740,6 +748,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
    load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -752,9 +762,15 @@ class UpdateWeightFromDiskReqOutput:
 
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-    name: str
-    dtype: str
-    shape: List[int]
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -776,6 +792,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -854,7 +872,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False
 
 
 @dataclass
@@ -1002,3 +1022,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the lora module to newly loaded.
+    lora_name: str
+    # The path of loading.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of lora module to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
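
Taken together, the io_struct changes above batch the distributed weight-update message, add abort-all semantics, and introduce LoRA load/unload messages. A minimal sketch constructing the new dataclasses as defined in this diff (field values are illustrative only, not real tensor names):

from sglang.srt.managers.io_struct import (
    AbortReq,
    LoadLoRAAdapterReqInput,
    UpdateWeightsFromDistributedReqInput,
)

# Abort everything in flight instead of a single rid.
abort_all = AbortReq(abort_all=True)

# Weight updates now take parallel lists of names/dtypes/shapes
# (the tensor name and shape below are made up for illustration).
update = UpdateWeightsFromDistributedReqInput(
    names=["model.layers.0.mlp.gate_proj.weight"],
    dtypes=["bfloat16"],
    shapes=[[14336, 4096]],
    abort_all_requests=True,
)

# Load a LoRA adapter by name and path.
load_req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/adapters/my-adapter")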
sglang/srt/managers/mm_utils.py
@@ -2,14 +2,15 @@
 Multi-modality utils
 """
 
-import dataclasses
-import logging
+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch import nn
 
+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
-        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
         """
-        pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        if not pad_values:
-            # No multimodal items, return original input_ids
+        if not input_ids or not mm_inputs.mm_items:
             return input_ids
-        if not input_ids:
-            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-        if not mask.any():
-            # No tokens match token_ids, return original input_ids
-            return input_ids
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}
 
-        # Find contiguous regions
-        padded_mask = torch.cat(
-            (
-                torch.tensor([False], device=device),
-                mask,
-                torch.tensor([False], device=device),
-            )
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
             else:
-                logger.warning(f"Skipping region {i} due to None pad_value.")
-        return output_ids_tensor.tolist()
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+
+        return ret_input_ids
 
 
 embedding_cache = None
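
The rewrite above drops the contiguous-region search in favor of a direct token-id to pad-value map. A self-contained sketch of the same replacement on dummy ids (the token id and pad value are made up; the real code reads them from MultimodalInputs and its mm_items):

import torch

IM_TOKEN_ID = 151655      # hypothetical image token id
IMAGE_PAD_VALUE = 912345  # hypothetical pad value (e.g. derived from a data hash)

input_ids = [1, 5, IM_TOKEN_ID, IM_TOKEN_ID, IM_TOKEN_ID, 9]
token_to_pad_mapping = {IM_TOKEN_ID: IMAGE_PAD_VALUE}

ids = torch.tensor(input_ids)
for token_id, pad_value in token_to_pad_mapping.items():
    ids[ids == token_id] = pad_value  # replace every occurrence at once
print(ids.tolist())  # [1, 5, 912345, 912345, 912345, 9]

One behavioral consequence worth noting: all items of a modality now share a single pad value per token id (the last item's wins), whereas the old code assigned each contiguous region its own item's pad_value.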
@@ -283,7 +248,9 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
@@ -304,7 +271,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
@@ -680,3 +647,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
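
The new data_hash above truncates a SHA-256 digest to 64 bits, while tensor_hash and hash_feature reduce tensors, lists, and arrays to bytes before hashing (with a GPU fast path via gpu_tensor_hash). A standalone re-implementation of the CPU path for illustration:

import hashlib

import numpy as np

def data_hash(data) -> int:
    # First 8 bytes of SHA-256, read as an unsigned big-endian 64-bit int.
    hash_bytes = hashlib.sha256(data).digest()[:8]
    return int.from_bytes(hash_bytes, byteorder="big", signed=False)

arr = np.arange(16, dtype=np.float32)
h = data_hash(np.ascontiguousarray(arr).tobytes())
assert 0 <= h < 2**64  # fits the 64-bit pad-value space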
sglang/srt/managers/multimodal_processor.py
@@ -3,11 +3,8 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache
 
-from sglang.srt.managers.multimodal_processors.base_processor import (
-    BaseMultimodalProcessor,
-)
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
@@ -27,9 +24,8 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()
 
 
-@lru_cache()
 def import_processors():
-    package_name = "sglang.srt.managers.multimodal_processors"
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
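
import_processors uses pkgutil to discover every processor module under the relocated package instead of importing each one by hand; dropping the lru_cache decorator is harmless since repeated imports are de-duplicated by sys.modules anyway. A generic sketch of the pattern ("myapp.plugins" is a placeholder package, not an sglang path):

import importlib
import pkgutil

def import_submodules(package_name: str) -> None:
    """Import every non-package module under package_name."""
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            # Side effect: module-level registration (e.g. `models = [...]`) runs.
            importlib.import_module(name)

# import_submodules("myapp.plugins")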
sglang/srt/managers/multimodal_processors/qwen_audio.py (new file)
@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
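
The two integer divisions in the new processor map mel-frame counts to encoder output lengths, consistent with a stride-2 convolution followed by a further 2x reduction in the Qwen2-Audio encoder (that reading of the arithmetic is mine, not stated in the diff). A worked example, assuming 30 s of audio at the usual 10 ms mel hop:

mel_frames = 3000                              # feature_attention_mask.sum(dim=-1)
input_lengths = (mel_frames - 1) // 2 + 1      # 1500 frames entering the encoder
output_lengths = (input_lengths - 2) // 2 + 1  # 750 audio embeddings out
assert (input_lengths, output_lengths) == (1500, 750)

output_lengths is what gets stored as audio_feature_lens on the MultimodalDataItem above, so downstream code reserves one placeholder position per encoder output frame.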