sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
     def increment(self, num_tokens: int):
         with self._lock:
             if self._done_flag:
-                return
+                return False
             self.completed_tokens += num_tokens
+            return True

     def mark_done(self):
         with self._lock:
@@ -219,6 +220,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +246,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -358,6 +366,7 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
+        torch.cuda.current_stream().synchronize()
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -520,12 +529,12 @@ class HiCacheController:
                         f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                     )
                     break
-                self.
-
-
-
-
-
+                if operation.increment(self.page_size):
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                else:
                     # operation terminated by controller, release pre-allocated memory
                     self.mem_pool_host.free(
                         operation.host_indices[operation.completed_tokens :]
@@ -567,13 +576,33 @@ class HiCacheController:
                 else:
                     break

+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
+                    self.mem_pool_host.free(operation.host_indices)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
                 else:
-                    operation.hash_value = hash_value
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                    )
                    self.prefetch_buffer.put(operation)

@@ -610,17 +639,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)

-
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )

             except Empty:
                 continue
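Note: the two cache_controller.py hunks above make the storage-backed HiCache path tensor-parallel safe by reducing each rank's local progress (the storage hit count during prefetch, the completed token count during backup) with a MIN all-reduce over a dedicated gloo group, so every TP worker commits to the same decision. A minimal standalone sketch of that pattern, assuming an already-initialized torch.distributed gloo group; the helper name agree_on_minimum is illustrative and not part of sglang:

import torch
import torch.distributed as dist

def agree_on_minimum(local_value: int, tp_group) -> int:
    """Return the minimum of local_value across all ranks in tp_group.

    Mirrors the MIN all-reduce used by HiCacheController so that every
    tensor-parallel worker makes the same prefetch/backup decision.
    """
    if dist.get_world_size(group=tp_group) == 1:
        return local_value
    value_tensor = torch.tensor(local_value, dtype=torch.int)  # CPU tensor works with the gloo backend
    dist.all_reduce(value_tensor, op=dist.ReduceOp.MIN, group=tp_group)
    return int(value_tensor.item())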
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+            lora_path=self.lora_path,
+        )


 @dataclass
 class UnloadLoRAAdapterReqInput:
     # The name of lora module to unload.
     lora_name: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+        )


 @dataclass
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str,
+    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
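Note: the to_ref() helpers added above convert the load/unload request dataclasses into LoRARef records from the new sglang/srt/lora/lora_registry.py module. A hedged usage sketch; the lora_id value below is illustrative, since in the server it is generated by the TokenizerManager:

from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

req = LoadLoRAAdapterReqInput(
    lora_name="my-adapter",
    lora_path="/models/loras/my-adapter",
    lora_id="0b1c2d3e",  # illustrative; normally assigned by the TokenizerManager
)
lora_ref = req.to_ref()  # LoRARef carrying lora_id, lora_name, and lora_path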
sglang/srt/managers/mm_utils.py
CHANGED
@@ -3,8 +3,9 @@ Multi-modality utils
 """

 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

 import numpy as np
 import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.

+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communications
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialize metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                    self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+

 class MultiModalityDataPaddingPattern:
     """
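Note: the TransportProxyTensor class added above lets a CUDA tensor cross process boundaries by pickling a CUDA IPC handle instead of the tensor payload, and falls back to ordinary tensor serialization whenever the handle cannot be shared. A rough usage sketch, assuming producer and consumer processes run on the same node and can open the same GPU:

import pickle
import torch

# producer: wrap a CUDA tensor so pickling ships only an IPC handle plus metadata
features = torch.randn(16, 1024, device="cuda")
proxy = TransportProxyTensor(
    data=features,
    name="pixel_values",
    fields={"modality": "image"},
    transport_mode="cuda_ipc",
)
payload = pickle.dumps(proxy)  # __getstate__ stores the CUDA IPC handle

# consumer: the bytes can be sent to another process on the same node;
# unpickling reopens the handle and reattaches the metadata
restored = pickle.loads(payload)
assert restored.is_cuda and restored.name == "pixel_values"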
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids =
-        end_tokens_ids =
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}

         padded_ids = []
         last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         if not input_ids or not mm_inputs.mm_items:
             return input_ids

-        input_ids_tensor = torch.
+        input_ids_tensor = torch.as_tensor(input_ids)

         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
@@ -211,7 +334,7 @@ def get_embedding_chunk(
         end_index += extend_end_index - start + 1
     elif extend_end_index > end:
         end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
             modality_id = modality.name.lower()
             embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
             if len(items) != 0 and embedder is not None:
-                placeholder_tensor = torch.
+                placeholder_tensor = torch.as_tensor(
                     [item.pad_value for item in items],
                     device=input_ids.device,
                 )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
         for embedding, mask in zip(embeddings, masks):
             if embedding is None or mask is None:
                 continue
-
-
-
-                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-            )
+            # in-place update
+            indices = torch.where(mask.squeeze(dim=-1))[0]
+            inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds


@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens =
-    end_tokens =
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)

     start_cond = torch.isin(
-        input_ids, torch.
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(
-            and input_ids[0] in pad_values
-            and
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-
-
-                torch.tensor([0], device=data_start_tokens.device),
-                data_start_tokens,
-            ]
-        )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token =
-        end_token =
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))

@@ -605,7 +728,7 @@ def get_multimodal_data_bounds(
         return torch.zeros((0, 2), device=input_ids.device)

     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor


@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
     ]
     tensor = torch.concat(tensor_list)
     if tensor.is_cuda:
-        return gpu_tensor_hash(tensor)
+        return gpu_tensor_hash(tensor.cuda())
     tensor = tensor.detach().contiguous()

     if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()

     assert isinstance(tensor, torch.Tensor)
-
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()

     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}


-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():


 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
+
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -45,7 +45,6 @@ import triton
 import triton.language as tl

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
@@ -106,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_multimodal",
 ]

 # Put some global args for easy access
@@ -208,10 +210,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None
+
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None
-
-    #
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
     precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

     # Model-specific data stored in a dictionary
@@ -430,6 +433,7 @@ class Req:
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
         data_parallel_rank: Optional[int] = None,
+        vocab_size: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -479,6 +483,7 @@ class Req:
         self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
+        self.vocab_size = vocab_size

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -712,6 +717,14 @@ class Req:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        if last_token_id > self.vocab_size or last_token_id < 0:
+            if self.sampling_params.stop_token_ids:
+                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
+            if self.eos_token_ids:
+                self.output_ids[-1] = next(iter(self.eos_token_ids))
+            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+            return
+
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
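Note: the Req hunk above adds a sanity check on the sampled token id; an id outside [0, vocab_size] is treated as numerical corruption (for example NaN logits), the last output token is rewritten to a known stop/EOS id, and the request is finished with the reason "NaN happened". A hedged, standalone restatement of that guard; the function and parameter names here are illustrative only:

def sanitize_last_token(output_ids, last_token_id, vocab_size, stop_token_ids, eos_token_ids):
    """Illustrative restatement of the guard added to Req.check_finished."""
    if last_token_id > vocab_size or last_token_id < 0:
        if stop_token_ids:
            output_ids[-1] = next(iter(stop_token_ids))
        if eos_token_ids:
            output_ids[-1] = next(iter(eos_token_ids))
        return True  # caller marks the request finished ("NaN happened")
    return False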
@@ -1677,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and
+                and attention_backend_str == "flashinfer"
             )
-            or
-            or
-            or
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -1879,7 +1896,7 @@ class ModelWorkerBatch:
     sampling_info: SamplingBatchInfo

     # The input Embeds
-    input_embeds: Optional[torch.
+    input_embeds: Optional[torch.Tensor] = None

     # For corss-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -1889,7 +1906,6 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    spec_num_draft_tokens: Optional[int] = None
     hicache_consumer_index: int = 0

     # Overlap event