sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/quick_all_reduce.py
ADDED
@@ -0,0 +1,273 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt import _custom_ops as ops
+from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
+    is_full_nvlink,
+    is_weak_contiguous,
+)
+from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+
+try:
+    ops.qr_max_size()
+    quick_ar = True
+except Exception:
+    # For CPUs and CUDA
+    quick_ar = False
+
+
+def qr_rocm_arch_available():
+    if not _is_hip:
+        return False
+    try:
+        props = torch.cuda.get_device_properties(0)
+        gcn_arch = getattr(props, "gcnArchName", "")
+        supported_archs = ["gfx94", "gfx95"]
+        return any(gfx in gcn_arch for gfx in supported_archs)
+    except Exception as e:
+        logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
+        return False
+
+
+class QuickReduceRegime(Enum):
+    FP = 0
+    INT8 = 1
+    INT6 = 2
+    INT4 = 3
+    NONE = 4
+
+
+MB = 1024 * 1024
+
+
+class QuickAllReduce:
+
+    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
+    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+    # The following data is based on kernel tests.
+    # In this order [FP, INT8, INT6, INT4].
+    _QR_MIN_SIZE = {
+        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
+        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
+        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
+        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
+    }
+
+    def __init__(
+        self, group: ProcessGroup, device: Union[int, str, torch.device]
+    ) -> None:
+        """
+        Custom allreduce provides non-destructive acceleration and is
+        available for CUDA and ROCm MI300 series.
+        Custom quick allreduce leverages quantization for further
+        acceleration on ROCm. It currently supports Q8, Q6, and Q4
+        quantization formats and FP(float16, bfloat16).
+        Quick allreduce is designed as a complement to custom allreduce.
+        Its initialization requires even stricter conditions.
+        Only the ROCm MI300 series is supported for quick allreduce at
+        this time.
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self.disabled = True
+        if not qr_rocm_arch_available():
+            logger.debug(
+                "Custom quick allreduce is only supported on ROCm MI300 series."
+            )
+            return
+
+        if not quick_ar:
+            # disable because of missing quick reduce library
+            # e.g. in a cuda environment
+            logger.info(
+                "Custom quick allreduce is disabled because "
+                "of missing custom quick allreduce library"
+            )
+            return
+
+        self.group = group
+        assert (
+            dist.get_backend(group) != dist.Backend.NCCL
+        ), "Custom quick allreduce should be attached to a non-NCCL group."
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom quick allreduce for
+            # multi-node case.
+            logger.warning(
+                "Custom quick allreduce is disabled because this "
+                "process group spans across nodes."
+            )
+            return
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        self.rank = rank
+        self.world_size = world_size
+        if world_size == 1:
+            # No need to initialize QuickReduce for single GPU case.
+            return
+
+        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom quick allreduce is disabled due to an "
+                "unsupported world size: %d. Supported world sizes: %s.",
+                world_size,
+                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
+            )
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(self.world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom quick allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        if _is_cuda or _is_hip:
+            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
+        if self.world_size > 2 and not self.fully_connected:
+            logger.debug(
+                "Custom quick allreduce is disabled because it's not supported "
+                "on more than two PCIe-only GPUs. "
+            )
+            return
+
+        self.init_quick_all_reduce()
+
+    def init_quick_all_reduce(self):
+        # On RocM, bfloat16 kernels are slower than fp16
+        # due to slower match operations
+        # If environment variable is set to 1, we convert input to fp16
+        self.use_fp16_kernels = int(
+            os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
+        )
+        regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
+        if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce:",
+                f"Invalid quantization level: {regime_str}. "
+                "Supported levels: "
+                f"{list(QuickReduceRegime.__members__.keys())}",
+            )
+            return
+
+        if regime_str == "NONE":
+            logger.debug(
+                "Custom quick allreduce is disabled based "
+                "on env variable "
+                "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
+            )
+            return
+        self.qr_quant_level = QuickReduceRegime[regime_str]
+
+        # TODO: If the dtype is not bfloat16 or then float16,
+        # quickallreduce should not be created.
+
+        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
+        qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
+        if qr_max_size > 0:
+            if qr_max_size < 1:
+                logger.info(
+                    "You should not set a max_size smaller than 1MB, which can "
+                    "lead to error or degradation to custom allreduce or rccl."
+                )
+            qr_max_size = qr_max_size * MB
+        # If qr_max_size is None, then 2GB is used by default.
+        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
+        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
+        self.create_shared_buffer()
+        self.disabled = False
+
+    def create_shared_buffer(self):
+        """
+        Creates a shared buffer for quickreduce.
+        Has to be called after init_custom_qr
+        """
+        handle = ops.qr_get_handle(self._ptr)
+        world_size = dist.get_world_size(group=self.group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=self.group)
+        ops.qr_open_handles(self._ptr, handles)
+
+    def should_quick_allreduce(self, inp: torch.Tensor):
+        """
+        Check if quickreduce is available
+        """
+        if self.disabled:
+            return False
+        if inp.dtype not in self._SUPPORTED_DTYPES:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom quick allreduce requires input byte size to be
+        # multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        dtype = inp.dtype
+        if self.use_fp16_kernels:
+            dtype = torch.float16
+        return (
+            inp_size <= self.qr_max_size
+            and inp_size
+            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
+        )
+
+    def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+        """Performs an out-of-place custom quick all reduce."""
+        # quick allreduce doesn't require a separate graph mode,
+        # as QR uses static IPC buffer.
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.qr_all_reduce(
+            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
+        )
+        return out
+
+    def close(self):
+        if not self.disabled and getattr(self, "_ptr", None):
+            if ops is not None:
+                ops.qr_destroy(self._ptr)
+            self._ptr = 0
+            self.disabled = True
+
+    def __del__(self):
+        self.close()
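The core of the new communicator is the size gate in should_quick_allreduce: quantized quick reduce only pays off when the tensor's byte size is at least the per-(dtype, world size, quantization regime) threshold from _QR_MIN_SIZE and no larger than the IPC buffer. A minimal standalone sketch of that gating logic, using two rows of the table above (the qualifies helper and its inputs are illustrative, not part of the package; the is_weak_contiguous check is omitted):

import torch

MB = 1024 * 1024
# Thresholds copied from QuickAllReduce._QR_MIN_SIZE, order [FP, INT8, INT6, INT4].
QR_MIN_SIZE = {
    (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
    (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}

def qualifies(inp, world_size, quant_level, qr_max_size, cast_bf16_to_fp16=True):
    # Mirrors should_quick_allreduce: supported dtype, 16-byte-aligned size,
    # and byte size inside the [min threshold, IPC buffer capacity] window.
    if inp.dtype not in (torch.float16, torch.bfloat16):
        return False
    inp_size = inp.numel() * inp.element_size()
    if inp_size % 16 != 0:
        return False
    dtype = torch.float16 if cast_bf16_to_fp16 else inp.dtype
    return QR_MIN_SIZE[(dtype, world_size)][quant_level] <= inp_size <= qr_max_size

# 16M fp16 elements = 32 MB, above the 4 MB INT8 threshold for 8 GPUs.
x = torch.empty(16 * 1024 * 1024, dtype=torch.float16)
print(qualifies(x, world_size=8, quant_level=1, qr_max_size=2048 * MB))  # True

In practice the regime comes from ROCM_QUICK_REDUCE_QUANTIZATION (FP/INT8/INT6/INT4, default NONE, i.e. disabled), and ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=1 makes bfloat16 inputs use the fp16 thresholds.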
sglang/srt/distributed/device_communicators/shm_broadcast.py
CHANGED
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
 from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    format_tcp_address,
+    get_ip,
+    get_open_port,
+    is_valid_ipv6_address,
+)
 
 # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
 SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
@@ -225,9 +230,9 @@ class MessageQueue:
             remote_subscribe_port = get_open_port()
             if is_valid_ipv6_address(connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-
-
-
+            self.remote_socket.bind(
+                format_tcp_address(connect_ip, remote_subscribe_port)
+            )
 
         else:
             remote_subscribe_port = None
@@ -288,7 +293,9 @@ class MessageQueue:
             self.remote_socket.setsockopt_string(SUBSCRIBE, "")
             if is_valid_ipv6_address(handle.connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-            socket_addr =
+            socket_addr = format_tcp_address(
+                handle.connect_ip, handle.remote_subscribe_port
+            )
             logger.debug("Connecting to %s", socket_addr)
             self.remote_socket.connect(socket_addr)
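Both call sites now build the ZeroMQ endpoint through format_tcp_address from sglang.srt.utils, whose body is not shown in this diff. A plausible reimplementation, assuming it only brackets IPv6 literals the way ZeroMQ endpoints require:

def format_tcp_address(ip: str, port: int) -> str:
    # Hypothetical stand-in: ZeroMQ wants IPv6 literals bracketed,
    # e.g. tcp://[::1]:5555; IPv4 addresses and hostnames stay bare.
    host = f"[{ip}]" if ":" in ip else ip
    return f"tcp://{host}:{port}"

assert format_tcp_address("10.0.0.2", 5555) == "tcp://10.0.0.2:5555"
assert format_tcp_address("::1", 5555) == "tcp://[::1]:5555"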
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     is_cuda_alike,
+    is_hip,
     is_npu,
     is_shm_available,
     supports_custom_op,
@@ -126,14 +127,18 @@ if supports_custom_op():
         fake_impl=inplace_all_reduce_fake,
     )
 
-    def outplace_all_reduce(
+    def outplace_all_reduce(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         assert group_name in _groups, f"Group {group_name} is not found."
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+        return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
 
-    def outplace_all_reduce_fake(
+    def outplace_all_reduce_fake(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         return torch.empty_like(tensor)
 
     direct_register_custom_op(
@@ -264,6 +269,12 @@ class GroupCoordinator:
             PyNcclCommunicator,
         )
 
+        if is_hip():
+            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
+                QuickAllReduce,
+                qr_rocm_arch_available,
+            )
+
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
@@ -283,6 +294,7 @@ class GroupCoordinator:
         )
 
         self.ca_comm: Optional[CustomAllreduce] = None
+        self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             try:
@@ -295,6 +307,18 @@ class GroupCoordinator:
                     f"Setup Custom allreduce failed with {e}. To silence this "
                     "warning, specify --disable-custom-all-reduce explicitly."
                 )
+        if is_hip():
+            try:
+                # Initialize a custom quick all-reduce implementation for AMD
+                # when rocm >= gfx942. Quick reduce is designed as a
+                # complement to custom allreduce.
+                # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+                if qr_rocm_arch_available():
+                    self.qr_comm = QuickAllReduce(
+                        group=self.cpu_group, device=self.device
+                    )
+            except Exception as e:
+                logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -373,7 +397,8 @@ class GroupCoordinator:
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
-
+        # We don't need the context of custom quick allreduce because the ipc access
+        # is already collected in init() and we can capture the quick allreduce directly.
         ca_comm = self.ca_comm
         maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
 
@@ -388,23 +413,24 @@ class GroupCoordinator:
         # operations. The current status is:
         # allreduce \ Mode   |  Eager  |  Graph  |
         # --------------------------------------------
+        # quick allreduce    | enabled | enabled |
         # custom allreduce   | enabled | enabled |
         # PyNccl             | disabled| enabled |
         # PyMscclpp          | disabled| enabled |
         # torch.distributed  | enabled | disabled|
         #
+        # Note: When custom quick allreduce is enabled, a runtime check
+        # will be performed. If the tensor size is too small, it will
+        # automatically fall back to the next available option.
        # Note that custom allreduce will have a runtime check, if the
        # tensor size is too large, it will fallback to the next
        # available option.
         # Note that the PyMsccl needs to register the tensor in ahead,
         # which will introduce large overhead in the eager case,
         # therefore it is only supported in the graph case.
-        # In summary:
-        #
-        #
-        # PyTorch NCCL. We always prioritize using custom all-reduce
-        # kernel but fall back to PyTorch or pynccl if it is
-        # disabled or not supported.
+        # In summary: We select the appropriate allreduce method for
+        # each mode based on the algorithm order in the table and
+        # their usage conditions.
         pynccl_comm = self.pynccl_comm
         maybe_pynccl_context: Any
         if not pynccl_comm:
@@ -464,27 +490,47 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        outplace_all_reduce_method = None
         if (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
+        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
-        )
+        ):
+            outplace_all_reduce_method = "ca"
+        elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
+            outplace_all_reduce_method = "pymscclpp"
+        if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
-                input_,
+                input_,
+                group_name=self.unique_name,
+                outplace_all_reduce_method=outplace_all_reduce_method,
             )
         else:
             torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
             return input_
 
-    def _all_reduce_out_place(
+    def _all_reduce_out_place(
+        self, input_: torch.Tensor, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
+        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
         pymscclpp_comm = self.pymscclpp_comm
-        assert ca_comm
-        if
+        assert any([qr_comm, ca_comm, pymscclpp_comm])
+        if outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "ca":
+            assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
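Seen end to end, the dispatch above is a simple priority chain: quick reduce first, then custom allreduce, then PyMscclpp, and finally the in-place pynccl/torch.distributed path when no out-of-place backend accepts the tensor. A standalone sketch of that selection pattern (FakeComm and its unified should_use method are illustrative; the real communicators each expose their own should_* check):

class FakeComm:
    def __init__(self, disabled: bool, accepts: bool):
        self.disabled = disabled
        self._accepts = accepts

    def should_use(self, tensor) -> bool:
        return self._accepts

def pick_backend(qr, ca, pymscclpp, tensor):
    # Mirrors GroupCoordinator.all_reduce: the first communicator that exists,
    # is enabled, and accepts this tensor wins; None means "in-place fallback".
    for name, comm in (("qr", qr), ("ca", ca), ("pymscclpp", pymscclpp)):
        if comm is not None and not comm.disabled and comm.should_use(tensor):
            return name
    return None

# Quick reduce disabled (e.g. a CUDA-only box) -> custom allreduce is chosen.
print(pick_backend(FakeComm(True, True), FakeComm(False, True), None, object()))  # ca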
@@ -499,6 +545,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
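The new reduce_scatter_tensor wrapper delegates straight to torch.distributed; the contract is that input holds world_size concatenated shards and each rank receives its reduced shard in output. A shape sketch under an assumed 4-rank group (process-group initialization omitted):

import torch
import torch.distributed as dist

def demo_reduce_scatter(group, rank: int, world_size: int = 4, n: int = 1024):
    # Each rank contributes world_size * n elements and gets back the
    # reduced (summed, by default) n-element shard for its own rank.
    inp = torch.full((world_size * n,), float(rank + 1), device="cuda")
    out = torch.empty(n, device="cuda")
    dist.reduce_scatter_tensor(out, inp, group=group)
    # Elementwise, out == 1 + 2 + 3 + 4 == 10 for every rank.
    return out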
@@ -1065,8 +1120,23 @@ def init_model_parallel_group(
 
 _TP: Optional[GroupCoordinator] = None
 
+# duplicate GroupCoordinator for prefill in PD-Multiplexing
+_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
+
+_ENABLE_PDMUX_P_TP: bool = False
+
+
+def set_pdmux_status(enable_prefill_multiplexing: bool):
+    global _ENABLE_PDMUX_P_TP
+    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
+
 
 def get_tp_group() -> GroupCoordinator:
+    if _ENABLE_PDMUX_P_TP:
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is not None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
+        return _PDMUX_PREFILL_TP_GROUP
     assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
 
@@ -1182,6 +1252,7 @@ def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
+    duplicate_tp_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1239,6 +1310,23 @@
         group_name="tp",
     )
 
+    if duplicate_tp_group:
+        global _PDMUX_PREFILL_TP_GROUP
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
+        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            use_message_queue_broadcaster=get_bool_env_var(
+                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
+            ),
+            group_name="pdmux_prefill_tp",
+        )
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
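The PD-Multiplexing hooks work as a process-wide toggle: initialize_model_parallel(duplicate_tp_group=True) builds the second coordinator, and set_pdmux_status() decides which one get_tp_group() hands out. A hedged usage sketch (distributed setup omitted; the 8-way TP size is an arbitrary example):

from sglang.srt.distributed import parallel_state as ps

# Assumed to run after the distributed environment is initialized:
ps.initialize_model_parallel(
    tensor_model_parallel_size=8,
    duplicate_tp_group=True,  # also creates the "pdmux_prefill_tp" group
)

ps.set_pdmux_status(True)    # prefill phase: get_tp_group() -> prefill duplicate
prefill_tp = ps.get_tp_group()

ps.set_pdmux_status(False)   # decode phase: back to the default _TP group
decode_tp = ps.get_tp_group()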
sglang/srt/entrypoints/engine.py
CHANGED
@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
-    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    MultimodalDataInputFormat,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -148,13 +147,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-
-
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -187,6 +182,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -231,13 +228,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-
-
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -272,6 +265,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -295,19 +290,20 @@
     def encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[
-
-
-                List[Union[Image, str]],
-                Union[Image, str],
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
     ) -> Dict:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
         Please refer to `EmbeddingReqInput` for the documentation.
         """
-        obj = EmbeddingReqInput(
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
         ret = loop.run_until_complete(generator.__anext__())
@@ -316,7 +312,9 @@ class Engine(EngineBase):
     async def async_encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
     ) -> Dict:
         """
         Asynchronous version of encode method.
@@ -324,7 +322,12 @@ class Engine(EngineBase):
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
         Please refer to `EmbeddingReqInput` for the documentation.
         """
-        obj = EmbeddingReqInput(
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
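For users, the visible effect is that Engine.generate and Engine.encode now accept audio_data and video_data alongside image_data, all typed as MultimodalDataInputFormat. A hedged sketch of the new surface (model path and file name are placeholders, and the chosen model must actually support the modality):

import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2-Audio-7B-Instruct")  # placeholder model

out = engine.generate(
    prompt="Transcribe the audio clip.",
    audio_data="./clip.wav",  # one item for a single request; lists batch it
    sampling_params={"max_new_tokens": 128},
)
print(out["text"])

engine.shutdown()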
@@ -633,16 +636,11 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.9rc1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -650,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.
+            "0.2.7",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
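The two version pins are the upgrade-relevant change here: flashinfer_python is now required at 0.2.9rc1 and sgl-kernel at 0.2.7 on CUDA. A rough standalone equivalent of such a pre-flight check (assert_pkg_version itself lives in sglang.srt.utils and may differ in detail):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse

def check_min_version(pkg: str, minimum: str) -> None:
    # Approximation of sglang's assert_pkg_version, for illustration only.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed") from None
    if parse(installed) < parse(minimum):
        raise RuntimeError(f"{pkg} {installed} is older than required {minimum}")

check_min_version("flashinfer_python", "0.2.9rc1")
check_min_version("sgl-kernel", "0.2.7")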