sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
     def increment(self, num_tokens: int):
         with self._lock:
             if self._done_flag:
-                return
+                return False
             self.completed_tokens += num_tokens
+            return True

     def mark_done(self):
         with self._lock:
@@ -219,6 +220,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +246,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -358,6 +366,7 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
+        torch.cuda.current_stream().synchronize()
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -520,12 +529,12 @@ class HiCacheController:
                             f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                         )
                         break
-                    self.mem_pool_host.set_from_flat_data_page(
-                        operation.host_indices[operation.completed_tokens],
-                        page_data,
-                    )
-                    operation.increment(self.page_size)
-                    if operation.is_done():
+                    if operation.increment(self.page_size):
+                        self.mem_pool_host.set_from_flat_data_page(
+                            operation.host_indices[operation.completed_tokens],
+                            page_data,
+                        )
+                    else:
                         # operation terminated by controller, release pre-allocated memory
                         self.mem_pool_host.free(
                             operation.host_indices[operation.completed_tokens :]
@@ -567,13 +576,33 @@ class HiCacheController:
                     else:
                         break

+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
+                    self.mem_pool_host.free(operation.host_indices)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
                 else:
-                    operation.hash_value = hash_value
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                     )
                     self.prefetch_buffer.put(operation)

@@ -610,17 +639,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    # todo, handle failures in storage backend
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)

-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )

             except Empty:
                 continue
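
The HiCacheController changes above make storage prefetch and backup tensor-parallel aware: each rank all-reduces its local page count with a MIN over a dedicated gloo group, so every rank agrees on how many pages were actually hit or written. A minimal sketch of that synchronization step (not sglang code; it assumes torch.distributed is already initialized and tp_group is the gloo group created in __init__):

import torch
import torch.distributed as dist


def min_across_tp(local_count: int, tp_group) -> int:
    # gloo works on CPU tensors, which is why the controller builds a separate
    # gloo group instead of reusing the NCCL TP group for this bookkeeping
    t = torch.tensor(local_count, dtype=torch.int)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=tp_group)
    return int(t.item())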
sglang/srt/managers/io_struct.py

@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+            lora_path=self.lora_path,
+        )


 @dataclass
 class UnloadLoRAAdapterReqInput:
     # The name of lora module to unload.
     lora_name: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+        )


 @dataclass
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
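
Both LoRA request types above now resolve to a LoRARef via to_ref(). A hedged usage sketch: the LoRARef class itself lives in the new sglang/srt/lora/lora_registry.py, which is not shown in this diff, so its fields are inferred from the calls above, and the id value below is purely illustrative.

from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/models/my-adapter")
# Normally the TokenizerManager generates and assigns this id; made up here.
req.lora_id = "0f3c1e2d"
ref = req.to_ref()  # LoRARef(lora_id="0f3c1e2d", lora_name="my-adapter", lora_path="/models/my-adapter")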
sglang/srt/managers/mm_utils.py

@@ -3,8 +3,9 @@ Multi-modality utils
 """

 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

 import numpy as np
 import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.

+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communications
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialize metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                    self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+

 class MultiModalityDataPaddingPattern:
     """
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}

         padded_ids = []
         last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
         if not input_ids or not mm_inputs.mm_items:
             return input_ids

-        input_ids_tensor = torch.tensor(input_ids)
+        input_ids_tensor = torch.as_tensor(input_ids)

         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
@@ -211,7 +334,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.tensor(
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
             )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds


@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
     [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)

     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))

@@ -605,7 +728,7 @@ def get_multimodal_data_bounds(
         return torch.zeros((0, 2), device=input_ids.device)

     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor


@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
     ]
     tensor = torch.concat(tensor_list)
     if tensor.is_cuda:
-        return gpu_tensor_hash(tensor)
+        return gpu_tensor_hash(tensor.cuda())
     tensor = tensor.detach().contiguous()

     if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()

     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()

     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
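
The embed_mm_inputs hunk above replaces masked_scatter with an in-place row assignment at the masked positions. A toy comparison of the two forms (not from the diff), assuming inputs_embeds is [S, H], mask is [S, 1] with whole rows masked, and embedding is [num_masked, H]:

import torch

S, H = 6, 4
inputs_embeds = torch.zeros(S, H)
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)
embedding = torch.arange(3 * H, dtype=torch.float32).reshape(3, H)

# old form: scatter flattened source elements through an expanded boolean mask
out_old = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), embedding)

# new form: assign whole rows at the masked indices
indices = torch.where(mask.squeeze(dim=-1))[0]
out_new = inputs_embeds.clone()
out_new[indices] = embedding

assert torch.equal(out_old, out_new)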
sglang/srt/managers/multimodal_processor.py

@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}


-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():


 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
+
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
sglang/srt/managers/schedule_batch.py

@@ -45,7 +45,6 @@ import triton
 import triton.language as tl

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "enable_flashinfer_moe",
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
@@ -106,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_multimodal",
 ]

 # Put some global args for easy access
@@ -208,10 +210,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None
+
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None
-
-    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
     precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

     # Model-specific data stored in a dictionary
@@ -430,6 +433,7 @@ class Req:
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
         data_parallel_rank: Optional[int] = None,
+        vocab_size: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -479,6 +483,7 @@ class Req:
         self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
+        self.vocab_size = vocab_size

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -712,6 +717,14 @@ class Req:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        if last_token_id > self.vocab_size or last_token_id < 0:
+            if self.sampling_params.stop_token_ids:
+                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
+            if self.eos_token_ids:
+                self.output_ids[-1] = next(iter(self.eos_token_ids))
+            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+            return
+
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
@@ -1677,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens

+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and global_server_args_dict["attention_backend"] == "flashinfer"
+                and attention_backend_str == "flashinfer"
             )
-            or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "cutlass_mla"
-            or global_server_args_dict["attention_backend"] == "ascend"
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -1879,7 +1896,7 @@ class ModelWorkerBatch:
     sampling_info: SamplingBatchInfo

     # The input Embeds
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None

     # For corss-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -1889,7 +1906,6 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
-    spec_num_draft_tokens: Optional[int] = None
     hicache_consumer_index: int = 0

     # Overlap event
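
ScheduleBatch above now selects the attention backend per forward mode ("decode_attention_backend" vs "prefill_attention_backend") instead of reading a single "attention_backend" key. A toy sketch of that selection (not sglang code), using a plain dict in place of sglang's global_server_args_dict:

def pick_attention_backend(args: dict, is_decode_or_idle: bool) -> str:
    # decode/idle batches use the decode backend; extend/prefill batches use the other
    key = "decode_attention_backend" if is_decode_or_idle else "prefill_attention_backend"
    return args[key]


assert (
    pick_attention_backend(
        {"decode_attention_backend": "fa3", "prefill_attention_backend": "flashinfer"},
        is_decode_or_idle=True,
    )
    == "fa3"
)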