sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
35
35
  get_normalized_lora_weight_names,
36
36
  get_weight_name,
37
37
  )
38
+ from sglang.srt.managers.io_struct import LoRAUpdateResult
38
39
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
39
40
  from sglang.srt.utils import replace_submodule
40
41
 
@@ -98,44 +99,96 @@ class LoRAManager:
98
99
  ],
99
100
  )
100
101
 
101
- def load_lora_adapters(self, lora_paths: Dict[str, str]):
102
+ def create_lora_update_result(
103
+ self, success: bool, error_message: str = ""
104
+ ) -> LoRAUpdateResult:
105
+ return LoRAUpdateResult(
106
+ success=success,
107
+ error_message=error_message,
108
+ loaded_adapters={
109
+ name: config.path for name, config in self.configs.items()
110
+ },
111
+ )
112
+
113
+ def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
102
114
  """
103
115
  Load LoRA adapters from the specified paths.
104
- TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
105
116
 
106
117
  Args:
107
118
  lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
108
119
  If a LoRA adapter is already loaded, it will be skipped with a warning.
109
120
  """
110
121
 
122
+ results = []
111
123
  for lora_name, lora_path in lora_paths.items():
112
- if lora_name in self.loras:
113
- logger.warning(
114
- f"LoRA adapter {lora_name} is already loaded."
115
- "If you want to reload it, please unload it first."
116
- )
117
- continue
124
+ result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
125
+ results.append(result)
126
+
127
+ self.update_state_from_configs()
128
+
129
+ return self.create_lora_update_result(
130
+ success=all(result.success for result in results),
131
+ error_message="\n".join(
132
+ result.error_message for result in results if not result.success
133
+ ),
134
+ )
135
+
136
+ def load_lora_adapter(
137
+ self, lora_name: str, lora_path: str, update_state: bool = True
138
+ ) -> LoRAUpdateResult:
139
+ """
140
+ Load a single LoRA adapter from the specified path.
141
+
142
+ Args:
143
+ lora_name (str): The name of the LoRA adapter.
144
+ lora_path (str): The file path to the LoRA adapter.
145
+ update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
146
+ """
118
147
 
148
+ success = True
149
+ error_message = ""
150
+
151
+ if lora_name in self.loras:
152
+ success = False
153
+ error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
154
+
155
+ try:
119
156
  self.configs[lora_name] = LoRAConfig(lora_path)
157
+ except Exception as e:
158
+ success = False
159
+ error_message = (
160
+ f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
161
+ )
120
162
 
121
- self.update_state_from_configs()
163
+ if update_state:
164
+ self.update_state_from_configs()
165
+
166
+ return self.create_lora_update_result(
167
+ success=success,
168
+ error_message=error_message,
169
+ )
122
170
 
123
- def unload_lora_adapters(self, lora_names: Set[str]):
171
+ def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
124
172
  """
125
173
  Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
126
174
  delete the corresponding LoRA modules.
127
-
128
- Args:
129
- lora_names (Set[str]): A set of LoRA adapter names to unload.
130
175
  """
131
- for lora_name in lora_names:
132
- if lora_name in self.loras:
133
- del self.configs[lora_name]
134
- else:
135
- logger.warning(f"LoRA adapter {lora_name} is not loaded.")
176
+
177
+ success = True
178
+ error_message = ""
179
+ if lora_name in self.loras:
180
+ del self.configs[lora_name]
181
+ else:
182
+ error_message = f"LoRA adapter {lora_name} is not loaded."
183
+ success = False
136
184
 
137
185
  self.update_state_from_configs()
138
186
 
187
+ return self.create_lora_update_result(
188
+ success=success,
189
+ error_message=error_message,
190
+ )
191
+
139
192
  def prepare_lora_batch(self, forward_batch: ForwardBatch):
140
193
  # load active loras into lora memory pool
141
194
  cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ class LoRAManager:
372
425
  lora_adapter.initialize_weights()
373
426
  self.loras[name] = lora_adapter
374
427
 
375
- # Clean up unused LoRA adapters
376
- for name in self.loras:
428
+ # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
429
+ for name in list(self.loras):
377
430
  if name not in self.configs:
378
431
  logger.info(f"Unloading LoRA adapter {name}")
379
432
  del self.loras[name]
@@ -28,7 +28,7 @@ if __name__ == "__main__":
28
28
  parser = argparse.ArgumentParser()
29
29
  parser.add_argument("--url", type=str, default="http://localhost:30000")
30
30
  parser.add_argument("--log-requests", action="store_true")
31
- parser.add_argument("--log-requests-level", type=int, default=2)
31
+ parser.add_argument("--log-requests-level", type=int, default=3)
32
32
  parser.add_argument(
33
33
  "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
34
34
  )
@@ -20,19 +20,18 @@ import copy
20
20
  import uuid
21
21
  from dataclasses import dataclass, field
22
22
  from enum import Enum
23
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
23
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
24
24
 
25
- from sglang.srt.mm_utils import has_valid_data
25
+ from sglang.srt.managers.schedule_batch import BaseFinishReason
26
+ from sglang.srt.multimodal.mm_utils import has_valid_data
27
+ from sglang.srt.sampling.sampling_params import SamplingParams
26
28
 
27
- # handle serialization of Image for pydantic
29
+ # Handle serialization of Image for pydantic
28
30
  if TYPE_CHECKING:
29
31
  from PIL.Image import Image
30
32
  else:
31
33
  Image = Any
32
34
 
33
- from sglang.srt.managers.schedule_batch import BaseFinishReason
34
- from sglang.srt.sampling.sampling_params import SamplingParams
35
-
36
35
 
37
36
  @dataclass
38
37
  class SessionParams:
@@ -40,6 +39,7 @@ class SessionParams:
40
39
  rid: Optional[str] = None
41
40
  offset: Optional[int] = None
42
41
  replace: Optional[bool] = None
42
+ drop_previous_output: Optional[bool] = None
43
43
 
44
44
 
45
45
  AudioDataItem = Union[str, Dict]
@@ -182,6 +182,7 @@ class GenerateReqInput:
182
182
  # Determine parallel sample count
183
183
  if self.sampling_params is None:
184
184
  self.parallel_sample_num = 1
185
+ return
185
186
  elif isinstance(self.sampling_params, dict):
186
187
  self.parallel_sample_num = self.sampling_params.get("n", 1)
187
188
  else: # isinstance(self.sampling_params, list):
@@ -516,9 +517,6 @@ class EmbeddingReqInput:
516
517
  # For cross-encoder requests
517
518
  is_cross_encoder_request: bool = False
518
519
 
519
- def contains_mm_input(self) -> bool:
520
- return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
521
-
522
520
  def normalize_batch_and_arguments(self):
523
521
  # at least one of text, input_ids, or image should be provided
524
522
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +570,9 @@ class EmbeddingReqInput:
572
570
  self.rid = uuid.uuid4().hex
573
571
  return self.rid
574
572
 
573
+ def contains_mm_input(self) -> bool:
574
+ return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
575
+
575
576
  def __getitem__(self, i):
576
577
  if self.is_cross_encoder_request:
577
578
  return EmbeddingReqInput(
@@ -740,6 +741,8 @@ class UpdateWeightFromDiskReqInput:
740
741
  model_path: str
741
742
  # The format to load the weights
742
743
  load_format: Optional[str] = None
744
+ # Whether to abort all requests before updating weights
745
+ abort_all_requests: bool = False
743
746
 
744
747
 
745
748
  @dataclass
@@ -752,9 +755,15 @@ class UpdateWeightFromDiskReqOutput:
752
755
 
753
756
  @dataclass
754
757
  class UpdateWeightsFromDistributedReqInput:
755
- name: str
756
- dtype: str
757
- shape: List[int]
758
+ names: List[str]
759
+ dtypes: List[str]
760
+ shapes: List[List[int]]
761
+ # The group name
762
+ group_name: str = "weight_update_group"
763
+ # Whether to flush the cache after updating weights
764
+ flush_cache: bool = True
765
+ # Whether to abort all requests before updating weights
766
+ abort_all_requests: bool = False
758
767
 
759
768
 
760
769
  @dataclass
@@ -776,6 +785,8 @@ class UpdateWeightsFromTensorReqInput:
776
785
  load_format: Optional[str] = None
777
786
  # Whether to flush the cache after updating weights
778
787
  flush_cache: bool = True
788
+ # Whether to abort all requests before updating weights
789
+ abort_all_requests: bool = False
779
790
 
780
791
 
781
792
  @dataclass
@@ -854,7 +865,9 @@ class SlowDownReqOutput:
854
865
  @dataclass
855
866
  class AbortReq:
856
867
  # The request id
857
- rid: str
868
+ rid: str = ""
869
+ # Whether to abort all requests
870
+ abort_all: bool = False
858
871
 
859
872
 
860
873
  @dataclass
@@ -1002,3 +1015,27 @@ class RpcReqInput:
1002
1015
  class RpcReqOutput:
1003
1016
  success: bool
1004
1017
  message: str
1018
+
1019
+
1020
+ @dataclass
1021
+ class LoadLoRAAdapterReqInput:
1022
+ # The name of the lora module to newly loaded.
1023
+ lora_name: str
1024
+ # The path of loading.
1025
+ lora_path: str
1026
+
1027
+
1028
+ @dataclass
1029
+ class UnloadLoRAAdapterReqInput:
1030
+ # The name of lora module to unload.
1031
+ lora_name: str
1032
+
1033
+
1034
+ @dataclass
1035
+ class LoRAUpdateResult:
1036
+ success: bool
1037
+ error_message: Optional[str] = None
1038
+ loaded_adapters: Dict[str, str] = field(default_factory=dict)
1039
+
1040
+
1041
+ LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@@ -2,14 +2,15 @@
2
2
  Multi-modality utils
3
3
  """
4
4
 
5
- import dataclasses
6
- import logging
5
+ import hashlib
7
6
  from abc import abstractmethod
8
7
  from typing import Callable, List, Optional, Tuple
9
8
 
9
+ import numpy as np
10
10
  import torch
11
11
  from torch import nn
12
12
 
13
+ from sglang.srt.layers.multimodal import gpu_tensor_hash
13
14
  from sglang.srt.managers.schedule_batch import (
14
15
  Modality,
15
16
  MultimodalDataItem,
@@ -124,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
124
125
  e.g. <image><image>....<image>, or <audio><audio>...<audio>
125
126
  """
126
127
 
127
- def __init__(self, token_ids: List[int]) -> None:
128
- self.token_ids = token_ids
129
-
130
128
  def pad_input_tokens(
131
129
  self, input_ids: List[int], mm_inputs: MultimodalInputs
132
130
  ) -> List[int]:
133
131
  """
134
- Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
135
- and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
132
+ Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
133
+ Each modality (image, audio, video) is handled separately based on its token_id.
136
134
  """
137
- pad_values = [item.pad_value for item in mm_inputs.mm_items]
138
- if not pad_values:
139
- # No multimodal items, return original input_ids
135
+ if not input_ids or not mm_inputs.mm_items:
140
136
  return input_ids
141
- if not input_ids:
142
- return []
143
137
 
144
138
  input_ids_tensor = torch.tensor(input_ids)
145
- device = input_ids_tensor.device
146
- token_ids_tensor = torch.tensor(self.token_ids, device=device)
147
- mask = torch.isin(input_ids_tensor, token_ids_tensor)
148
139
 
149
- if not mask.any():
150
- # No tokens match token_ids, return original input_ids
151
- return input_ids
140
+ # Create mapping of token_ids to pad_values for each modality
141
+ token_to_pad_mapping = {}
152
142
 
153
- # Find contiguous regions
154
- padded_mask = torch.cat(
155
- (
156
- torch.tensor([False], device=device),
157
- mask,
158
- torch.tensor([False], device=device),
159
- )
160
- )
161
- # Find indices where the mask value changes
162
- diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
163
-
164
- # Start indices are where False changes to True
165
- starts = diff_indices[::2]
166
- # End indices are where True changes to False (exclusive index)
167
- ends = diff_indices[1::2]
168
-
169
- # Check if the number of regions matches the number of pad values
170
- if len(starts) != len(pad_values):
171
- # Maybe log a warning here?
172
- num_regions = len(starts)
173
- num_pad_values = len(pad_values)
174
- if num_regions > 0 and num_pad_values > 0:
175
- pad_values = (pad_values * (num_regions // num_pad_values + 1))[
176
- :num_regions
177
- ]
178
- else: # If no regions or no pad_values, this loop won't run anyway.
179
- pad_values = [] # Ensure pad_values is empty if starts is empty
180
-
181
- # Create a copy to modify
182
- output_ids_tensor = input_ids_tensor.clone()
183
-
184
- # Replace tokens in each region with the corresponding pad value
185
- # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
186
- for i in range(min(len(starts), len(pad_values))):
187
- start_idx = starts[i]
188
- end_idx = ends[i]
189
- pad_value = pad_values[i]
190
- if pad_value is not None: # Ensure pad_value is not None before assignment
191
- output_ids_tensor[start_idx:end_idx] = pad_value
143
+ for item in mm_inputs.mm_items:
144
+ if item.is_image() and mm_inputs.im_token_id is not None:
145
+ token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
146
+ elif item.is_audio() and mm_inputs.audio_token_id is not None:
147
+ token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
148
+ elif item.is_video() and mm_inputs.video_token_id is not None:
149
+ token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
192
150
  else:
193
- logger.warning(f"Skipping region {i} due to None pad_value.")
194
- return output_ids_tensor.tolist()
151
+ raise ValueError(f"No multimodal token id provided for {item.modality}")
152
+
153
+ # Apply replacements for all tokens at once
154
+ for token_id, pad_value in token_to_pad_mapping.items():
155
+ input_ids_tensor[input_ids_tensor == token_id] = pad_value
156
+
157
+ ret_input_ids = input_ids_tensor.tolist()
158
+
159
+ return ret_input_ids
195
160
 
196
161
 
197
162
  embedding_cache = None
@@ -680,3 +645,52 @@ def get_multimodal_data_bounds(
680
645
  # Convert valid pairs to tensor
681
646
  valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
682
647
  return valid_pairs_tensor
648
+
649
+
650
+ def data_hash(data) -> int:
651
+ hash_bytes = hashlib.sha256(data).digest()[:8]
652
+ return int.from_bytes(hash_bytes, byteorder="big", signed=False)
653
+
654
+
655
+ def tensor_hash(tensor_list) -> int:
656
+ """
657
+ hash a tensor or a tensor list
658
+ """
659
+ tensor = tensor_list
660
+ if isinstance(tensor_list, list):
661
+ tensor_list = flatten_nested_list(tensor_list)
662
+ tensor_list = [
663
+ x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
664
+ ]
665
+ tensor = torch.concat(tensor_list)
666
+ if tensor.is_cuda:
667
+ return gpu_tensor_hash(tensor)
668
+ tensor = tensor.detach().contiguous()
669
+
670
+ if tensor.dtype == torch.bfloat16:
671
+ # memoryview() doesn't support PyTorch's BFloat16 dtype
672
+ tensor = tensor.float()
673
+
674
+ assert isinstance(tensor, torch.Tensor)
675
+ if tensor.is_cuda:
676
+ # TODO: improve this
677
+ tensor_cpu = tensor.cpu()
678
+ else:
679
+ tensor_cpu = tensor
680
+
681
+ mv = memoryview(tensor_cpu.numpy())
682
+ return data_hash(mv.tobytes())
683
+
684
+
685
+ def hash_feature(f):
686
+ if isinstance(f, list):
687
+ if isinstance(f[0], torch.Tensor):
688
+ return tensor_hash(f)
689
+ return data_hash(tuple(flatten_nested_list(f)))
690
+ elif isinstance(f, np.ndarray):
691
+ arr = np.ascontiguousarray(f)
692
+ arr_bytes = arr.tobytes()
693
+ return data_hash(arr_bytes)
694
+ elif isinstance(f, torch.Tensor):
695
+ return tensor_hash([f])
696
+ return data_hash(f)
@@ -3,11 +3,8 @@ import importlib
3
3
  import inspect
4
4
  import logging
5
5
  import pkgutil
6
- from functools import lru_cache
7
6
 
8
- from sglang.srt.managers.multimodal_processors.base_processor import (
9
- BaseMultimodalProcessor,
10
- )
7
+ from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
11
8
  from sglang.srt.server_args import ServerArgs
12
9
 
13
10
  logger = logging.getLogger(__name__)
@@ -27,9 +24,8 @@ def get_dummy_processor():
27
24
  return DummyMultimodalProcessor()
28
25
 
29
26
 
30
- @lru_cache()
31
27
  def import_processors():
32
- package_name = "sglang.srt.managers.multimodal_processors"
28
+ package_name = "sglang.srt.multimodal.processors"
33
29
  package = importlib.import_module(package_name)
34
30
  for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
35
31
  if not ispkg:
@@ -0,0 +1,94 @@
1
+ import re
2
+ from typing import List, Union
3
+
4
+ import torch
5
+
6
+ from sglang.srt.managers.multimodal_processors.base_processor import (
7
+ BaseMultimodalProcessor,
8
+ MultimodalSpecialTokens,
9
+ )
10
+ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
11
+ from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
12
+
13
+
14
+ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
15
+ models = [Qwen2AudioForConditionalGeneration]
16
+
17
+ def __init__(self, hf_config, server_args, _processor):
18
+ super().__init__(hf_config, server_args, _processor)
19
+ self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
20
+ self.AUDIO_TOKEN_REGEX = re.compile(
21
+ r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
22
+ )
23
+
24
+ async def process_mm_data_async(
25
+ self,
26
+ image_data: List[Union[str, bytes]],
27
+ input_text,
28
+ request_obj,
29
+ max_req_input_len,
30
+ **kwargs,
31
+ ):
32
+ audio_data = request_obj.audio_data
33
+ if not isinstance(audio_data, list):
34
+ audio_data = [audio_data]
35
+
36
+ base_output = self.load_mm_data(
37
+ prompt=input_text,
38
+ max_req_input_len=max_req_input_len,
39
+ audio_data=audio_data,
40
+ multimodal_tokens=MultimodalSpecialTokens(
41
+ audio_token=self.AUDIO_TOKEN,
42
+ audio_token_regex=self.AUDIO_TOKEN_REGEX,
43
+ ),
44
+ )
45
+ if base_output is None:
46
+ return None
47
+
48
+ res = self.process_mm_data(
49
+ input_text=base_output.input_text,
50
+ audio=base_output.audios,
51
+ )
52
+
53
+ # Collect special token ids
54
+ tokenizer = self._processor.tokenizer
55
+ audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
56
+ audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
57
+ audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
58
+
59
+ items = []
60
+ input_ids = res["input_ids"].flatten()
61
+
62
+ if (
63
+ "input_features" in res
64
+ and res["input_features"] is not None
65
+ and len(res["input_features"]) != 0
66
+ ):
67
+ if audio_start_id is not None and audio_end_id is not None:
68
+ audio_offsets = self.get_mm_items_offset_by_pair(
69
+ input_ids=input_ids,
70
+ mm_start_id=audio_start_id,
71
+ mm_end_id=audio_end_id,
72
+ )
73
+ else:
74
+ audio_offsets = None
75
+
76
+ input_lengths = res["feature_attention_mask"].sum(dim=-1)
77
+ input_lengths = (input_lengths - 1) // 2 + 1
78
+ output_lengths = (input_lengths - 2) // 2 + 1
79
+
80
+ item = MultimodalDataItem(
81
+ audio_features=res["input_features"],
82
+ audio_feature_lens=output_lengths,
83
+ audio_offsets=audio_offsets,
84
+ modality=Modality.AUDIO,
85
+ )
86
+ items += [item]
87
+
88
+ return {
89
+ "mm_items": items,
90
+ "input_ids": input_ids.tolist(),
91
+ "audio_start_id": audio_start_id,
92
+ "audio_token_id": audio_token_id,
93
+ "audio_end_id": audio_end_id,
94
+ }