sglang-0.4.8-py3-none-any.whl → sglang-0.4.9-py3-none-any.whl

This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -20,19 +20,18 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
-from sglang.srt.mm_utils import has_valid_data
+from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams
 
-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any
 
-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-
 
 @dataclass
 class SessionParams:
@@ -40,6 +39,7 @@ class SessionParams:
     rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
+    drop_previous_output: Optional[bool] = None
 
 
 AudioDataItem = Union[str, Dict]
@@ -182,6 +182,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
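
For context on the hunk above: when sampling_params is None, the derivation now returns immediately after setting a single sample. A minimal standalone sketch of the two branches visible here (not the sglang API; the list branch is truncated in this hunk, so it is stubbed):

from typing import Any, Dict, List, Optional, Union

def parallel_sample_num(
    sampling_params: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]
) -> int:
    if sampling_params is None:
        return 1  # no sampling params: exactly one sample, nothing else to inspect
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)  # "n" is the parallel sample count
    raise NotImplementedError("per-request list of params; elided in this hunk")

assert parallel_sample_num(None) == 1
assert parallel_sample_num({"n": 4, "temperature": 0.7}) == 4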
@@ -319,8 +320,16 @@
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif not isinstance(self.rid, list):
-            raise ValueError("The rid should be a list for batch processing.")
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")
 
     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""
@@ -508,9 +517,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -564,6 +570,9 @@ class EmbeddingReqInput:
         self.rid = uuid.uuid4().hex
         return self.rid
 
+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
@@ -732,6 +741,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
     load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -744,9 +755,15 @@ class UpdateWeightFromDiskReqOutput:
 
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-    name: str
-    dtype: str
-    shape: List[int]
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
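
A hypothetical construction of the reshaped request, with illustrative tensor names, dtypes, and shapes; the three lists are parallel, one entry per weight updated in a single call:

from sglang.srt.managers.io_struct import UpdateWeightsFromDistributedReqInput

req = UpdateWeightsFromDistributedReqInput(
    names=["model.embed_tokens.weight", "lm_head.weight"],  # illustrative names
    dtypes=["bfloat16", "bfloat16"],
    shapes=[[151936, 4096], [151936, 4096]],                # illustrative shapes
    group_name="weight_update_group",  # default group name
    flush_cache=True,                  # flush the cache after updating weights
    abort_all_requests=False,          # keep in-flight requests running
)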
@@ -768,6 +785,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -846,7 +865,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False
 
 
 @dataclass
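
With the new defaults, AbortReq can target one request by id or, when abort_all=True, every in-flight request. A short usage sketch (the rid is illustrative):

from sglang.srt.managers.io_struct import AbortReq

abort_one = AbortReq(rid="req_0")         # abort a single request by id
abort_everything = AbortReq(abort_all=True)  # rid keeps its "" default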
@@ -994,3 +1015,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the LoRA module to be newly loaded.
+    lora_name: str
+    # The path to load the adapter from.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of the LoRA module to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
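
A usage sketch of the new LoRA control messages; the adapter name and path are illustrative. Load and unload share LoRAUpdateResult as their response type via the alias above:

from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    LoRAUpdateResult,
    UnloadLoRAAdapterReqInput,
)

load_req = LoadLoRAAdapterReqInput(
    lora_name="sql-adapter", lora_path="/models/loras/sql"  # illustrative values
)
unload_req = UnloadLoRAAdapterReqInput(lora_name="sql-adapter")
result = LoRAUpdateResult(
    success=True, loaded_adapters={"sql-adapter": "/models/loras/sql"}
)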
@@ -2,14 +2,15 @@
 Multi-modality utils
 """
 
-import dataclasses
-import logging
+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch import nn
 
+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """
 
-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
-        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
         """
-        pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        if not pad_values:
-            # No multimodal items, return original input_ids
+        if not input_ids or not mm_inputs.mm_items:
             return input_ids
-        if not input_ids:
-            return []
 
         input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)
 
-        if not mask.any():
-            # No tokens match token_ids, return original input_ids
-            return input_ids
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}
 
-        # Find contiguous regions
-        padded_mask = torch.cat(
-            (
-                torch.tensor([False], device=device),
-                mask,
-                torch.tensor([False], device=device),
-            )
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
             else:
-                logger.warning(f"Skipping region {i} due to None pad_value.")
-        return output_ids_tensor.tolist()
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+
+        return ret_input_ids
 
 
 embedding_cache = None
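
A standalone sketch (plain torch, no sglang types) of the simplified padding scheme above: every occurrence of a modality's placeholder token is overwritten with that modality's pad_value in one vectorized pass. The token ids and pad values are illustrative:

import torch

IM_TOKEN_ID, AUDIO_TOKEN_ID = 151655, 151646  # illustrative placeholder ids
token_to_pad_mapping = {IM_TOKEN_ID: -100, AUDIO_TOKEN_ID: -200}

input_ids = torch.tensor([1, IM_TOKEN_ID, IM_TOKEN_ID, 2, AUDIO_TOKEN_ID, 3])
for token_id, pad_value in token_to_pad_mapping.items():
    input_ids[input_ids == token_id] = pad_value  # boolean-mask assignment

assert input_ids.tolist() == [1, -100, -100, 2, -200, 3]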
@@ -680,3 +645,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    Hash a tensor or a list of tensors.
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
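
A standalone sketch of the hashing scheme introduced above: SHA-256 over the raw bytes, truncated to the first 8 bytes and read as an unsigned big-endian integer, so identical feature buffers hash identically across processes:

import hashlib

import numpy as np

def data_hash(data) -> int:
    # First 8 bytes of SHA-256, interpreted as an unsigned 64-bit integer
    hash_bytes = hashlib.sha256(data).digest()[:8]
    return int.from_bytes(hash_bytes, byteorder="big", signed=False)

arr = np.arange(16, dtype=np.float32)
h1 = data_hash(np.ascontiguousarray(arr).tobytes())
h2 = data_hash(np.ascontiguousarray(arr.copy()).tobytes())
assert h1 == h2 and 0 <= h1 < 2**64  # stable across equal buffers, fits in 64 bits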
@@ -3,11 +3,8 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache
 
-from sglang.srt.managers.multimodal_processors.base_processor import (
-    BaseMultimodalProcessor,
-)
+from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
@@ -27,9 +24,8 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()
 
 
-@lru_cache()
 def import_processors():
-    package_name = "sglang.srt.managers.multimodal_processors"
+    package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
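
The discovery pattern behind import_processors, as a standalone sketch (the helper name here is hypothetical): import every non-package module inside a package so that processor classes register as an import side effect:

import importlib
import pkgutil

def import_all_modules(package_name: str) -> list:
    # Import the package, then walk its immediate modules and import each one
    package = importlib.import_module(package_name)
    return [
        importlib.import_module(name)
        for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + ".")
        if not ispkg
    ]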
@@ -0,0 +1,94 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
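
A worked example of the length arithmetic near the end of the new file: each (x - k) // 2 + 1 step appears to model one stride-2 downsampling stage, first in the feature extractor and then in the audio encoder, so the per-clip embedding count lands at roughly a quarter of the mel frame count. The frame count below is illustrative:

mel_frames = 3000                              # e.g. 30 s of audio at 100 mel frames/s
input_lengths = (mel_frames - 1) // 2 + 1      # after the first stride-2 stage -> 1500
output_lengths = (input_lengths - 2) // 2 + 1  # after the second stage -> 750
assert (input_lengths, output_lengths) == (1500, 750)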