sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -156,6 +156,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    enable_llama4_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -185,6 +186,7 @@ class ServerArgs:
     warmups: Optional[str] = None
     n_share_experts_fusion: int = 0
     disable_shared_experts_fusion: bool = False
+    disable_chunked_prefix_cache: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -194,6 +196,10 @@ class ServerArgs:
     # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
     disaggregation_mode: str = "null"
     disaggregation_bootstrap_port: int = 8998
+    disaggregation_transfer_backend: str = "mooncake"
+
+    # multimodal
+    disable_fast_image_processor: bool = False
 
     def __post_init__(self):
         # Expert parallelism
@@ -294,6 +300,8 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -495,6 +503,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "modelopt_fp4",
                 "w8a8_int8",
                 "w8a8_fp8",
                 "moe_wna16",
@@ -973,6 +982,12 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
         )
+        parser.add_argument(
+            "--enable-llama4-multimodal",
+            default=ServerArgs.enable_llama4_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for Llama-4.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
@@ -1100,6 +1115,7 @@ class ServerArgs:
            "--deepep-mode",
            type=str,
            choices=["normal", "low_latency", "auto"],
+           default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
 
@@ -1115,6 +1131,11 @@ class ServerArgs:
            action="store_true",
            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
        )
+       parser.add_argument(
+           "--disable-chunked-prefix-cache",
+           action="store_true",
+           help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
+       )
 
        # Server warmups
        parser.add_argument(
@@ -1159,6 +1180,19 @@ class ServerArgs:
            default=ServerArgs.disaggregation_bootstrap_port,
            help="Bootstrap server port on the prefill server. Default is 8998.",
        )
+       parser.add_argument(
+           "--disaggregation-transfer-backend",
+           type=str,
+           default=ServerArgs.disaggregation_transfer_backend,
+           help="The backend for disaggregation transfer. Default is mooncake.",
+       )
+
+       # Multimodal
+       parser.add_argument(
+           "--disable-fast-image-processor",
+           action="store_true",
+           help="Adopt base image processor instead of fast image processor.",
+       )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
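The server_args.py changes above add plain dataclass fields, so the new options can be set from Python as well as through the new CLI flags. A minimal sketch under stated assumptions: the model id is only an illustrative placeholder, and running __post_init__ end to end assumes a working sglang 0.4.5.post1 install; none of the values below come from the diff itself except the field names.

# Hypothetical illustration of the new ServerArgs fields; not part of the diff.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder model id
    enable_llama4_multimodal=True,               # new field / --enable-llama4-multimodal
    disaggregation_mode="prefill",               # existing PD-disaggregation switch
    disaggregation_transfer_backend="mooncake",  # new field / --disaggregation-transfer-backend
    disable_chunked_prefix_cache=False,          # new field / --disable-chunked-prefix-cache
    disable_fast_image_processor=False,          # new field / --disable-fast-image-processor
)
# __post_init__ copies enable_llama4_multimodal into the generic enable_multimodal attribute.
print(args.enable_multimodal)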
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -84,10 +84,10 @@ class EAGLEDraftCudaGraphRunner:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
-                "1.
-                "2.
-                "3.
-                "4.
+                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+                "2. disable torch compile by not using --enable-torch-compile\n"
+                "3. specify --dtype to the same dtype (e.g. bfloat16)\n"
+                "4. disable cuda graph by --disable-cuda-graph\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
 
 if is_cuda_available():
     from sgl_kernel import (
@@ -772,16 +772,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info
 
 
-def fast_topk(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    empty_context,
+    fast_topk,
+    get_available_gpu_memory,
+    is_cuda_available,
+)
 
 if is_cuda_available():
     from sgl_kernel import segment_packbits
sglang/srt/utils.py
CHANGED
@@ -16,6 +16,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import importlib
 import io
 import ipaddress
 import itertools
@@ -127,7 +128,7 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return is_cuda()
+    return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
 def is_cuda_available():
@@ -568,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):
 
 
 def load_image(
-    image_file: Union[Image.Image, str, bytes]
+    image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
     image = image_size = None
     if isinstance(image_file, Image.Image):
@@ -845,33 +846,38 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
+    force_cpu_device: bool = True,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
+    )
 
     if rank == 0:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
+
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            )
-            tensor_size = torch.tensor([size], dtype=torch.long)
+            ).to(device)
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
 
             dist.broadcast(tensor_size, src=src, group=dist_group)
             dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
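With this change broadcast_pyobj defaults to CPU tensors (force_cpu_device=True); passing False keeps the size and payload tensors on CUDA when a GPU is available, which matches what an NCCL-backed process group expects. A sketch of both call patterns, assuming the caller has already initialized torch.distributed; the group setup is not part of this diff.

# Assumes dist.init_process_group(...) has been called elsewhere.
import torch.distributed as dist
from sglang.srt.utils import broadcast_pyobj

payload = ["any", {"picklable": "object"}]

# Default path: helper tensors stay on CPU, suitable for a gloo group.
received = broadcast_pyobj(payload, rank=dist.get_rank(), dist_group=None)

# GPU path: with force_cpu_device=False (and CUDA available), the helper keeps its
# tensors on "cuda", which is what an NCCL group expects.
received = broadcast_pyobj(
    payload, rank=dist.get_rank(), dist_group=None, force_cpu_device=False
)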
@@ -1480,14 +1486,43 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
 
 class MultiprocessingSerializer:
     @staticmethod
-    def serialize(obj):
+    def serialize(obj, output_str: bool = False):
+        """
+        Serialize a Python object using ForkingPickler.
+
+        Args:
+            obj: The object to serialize.
+            output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+        Returns:
+            bytes or str: The serialized object.
+        """
         buf = io.BytesIO()
         ForkingPickler(buf).dump(obj)
         buf.seek(0)
-
+        output = buf.read()
+
+        if output_str:
+            # Convert bytes to base64-encoded string
+            output = base64.b64encode(output).decode("utf-8")
+
+        return output
 
     @staticmethod
     def deserialize(data):
+        """
+        Deserialize a previously serialized object.
+
+        Args:
+            data (bytes or str): The serialized data, optionally base64-encoded.
+
+        Returns:
+            The deserialized Python object.
+        """
+        if isinstance(data, str):
+            # Decode base64 string to bytes
+            data = base64.b64decode(data)
+
         return ForkingPickler.loads(data)
 
 
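serialize() can now emit base64 text instead of raw bytes, and deserialize() accepts either form, so payloads can ride over text-only channels such as JSON. A small round-trip sketch; the example object is arbitrary.

from sglang.srt.utils import MultiprocessingSerializer

obj = {"step": 3, "shards": [0, 1, 2]}

as_bytes = MultiprocessingSerializer.serialize(obj)                   # raw bytes, as before
as_text = MultiprocessingSerializer.serialize(obj, output_str=True)   # base64-encoded str

# deserialize() base64-decodes str inputs before unpickling, so both forms round-trip.
assert MultiprocessingSerializer.deserialize(as_bytes) == obj
assert MultiprocessingSerializer.deserialize(as_text) == obj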
@@ -1819,3 +1854,79 @@ class DeepEPMode(Enum):
             return DeepEPMode.low_latency
         else:
             return DeepEPMode.normal
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        return torch.max(values, dim=dim, keepdim=True)
+    else:
+        # Use topk for efficiency with larger k values
+        return torch.topk(values, topk, dim=dim)
+
+
+def is_hopper_with_cuda_12_3():
+    if not is_cuda():
+        return False
+    is_hopper = torch.cuda.get_device_capability()[0] == 9
+    cuda_version = torch.version.cuda.split(".")
+    is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+    return is_hopper and is_cuda_compatible
+
+
+def get_free_port():
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def get_local_ip_by_remote() -> str:
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        raise ValueError(f"Can not get local ip")
+
+
+def is_page_size_one(server_args):
+    return server_args.page_size == 1
+
+
+def is_no_spec_infer_or_topk_one(server_args):
+    return server_args.speculative_eagle_topk is None or (
+        server_args.speculative_eagle_topk is not None
+        and server_args.speculative_eagle_topk == 1
+        and is_page_size_one(server_args)
+    )
+
+
+def is_fa3_default_architecture(hf_config):
+    architectures = getattr(hf_config, "architectures", None)
+    if not isinstance(architectures, list) or not architectures:
+        return False
+    default_archs = {
+        "Qwen2ForCausalLM",
+        "Llama4ForConditionalGeneration",
+        "LlamaForCausalLM",
+        "MistralForCausalLM",
+    }
+    return architectures[0] in default_archs
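fast_topk, previously defined in eagle_utils, is now a shared helper in sglang.srt.utils; its topk == 1 branch uses torch.max(..., keepdim=True), so the output shape matches torch.topk. A quick sketch of that equivalence with illustrative values only.

import torch
from sglang.srt.utils import fast_topk

scores = torch.randn(4, 32)

values1, indices1 = fast_topk(scores, 1, dim=-1)   # torch.max(..., keepdim=True) path
values4, indices4 = fast_topk(scores, 4, dim=-1)   # falls through to torch.topk

assert values1.shape == (4, 1) and indices1.shape == (4, 1)
assert torch.equal(values1, torch.topk(scores, 1, dim=-1).values)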