sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
```diff
@@ -15,7 +15,6 @@

 from __future__ import annotations

-import base64
 import builtins
 import ctypes
 import dataclasses
@@ -68,6 +67,7 @@ from typing import (

 import numpy as np
 import psutil
+import pybase64
 import requests
 import torch
 import torch.distributed
@@ -83,12 +83,7 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+from triton.runtime.cache import FileCacheManager

 logger = logging.getLogger(__name__)

@@ -202,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:


 def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
+    return backend not in ["torch_native", "intel_amx", "ascend"]


 try:
```
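The `support_triton` change adds the Ascend backend to the list of backends that bypass Triton. A minimal sketch of the observable behavior, assuming the import path above:

```python
# Sketch: backends not in the exclusion list keep using Triton kernels.
from sglang.srt.utils import support_triton

assert support_triton("flashinfer")
assert not support_triton("intel_amx")
assert not support_triton("ascend")  # newly excluded in 0.4.9.post2
```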
```diff
@@ -621,7 +616,7 @@ def decode_video_base64(video_base64):
     from PIL import Image

     # Decode the base64 string
-    video_bytes = base64.b64decode(video_base64)
+    video_bytes = pybase64.b64decode(video_base64, validate=True)

     # Placeholder for the start indices of each PNG image
     img_starts = []
@@ -707,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
-        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+        audio, original_sr = sf.read(
+            BytesIO(pybase64.b64decode(audio_file, validate=True))
+        )
     elif audio_file.startswith("http://") or audio_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
         response = requests.get(audio_file, stream=True, timeout=timeout)
```
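These hunks are part of a release-wide swap from the stdlib `base64` module to `pybase64`, a faster SIMD implementation with the same API. The sketch below (illustrative values) shows the drop-in compatibility and the effect of `validate=True`, which rejects non-alphabet characters instead of silently discarding them:

```python
import base64
import pybase64

payload = b"example payload"
encoded = pybase64.b64encode(payload).decode("utf-8")
# Same output as the stdlib encoder.
assert encoded == base64.b64encode(payload).decode("utf-8")
assert pybase64.b64decode(encoded, validate=True) == payload

try:
    pybase64.b64decode("not base64!!", validate=True)
except Exception as exc:  # raises binascii.Error on invalid input
    print(f"rejected: {exc}")
```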
```diff
@@ -731,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray
     return audio


-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
@@ -776,18 +746,70 @@ def load_image(
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     elif isinstance(image_file, str):
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
         raise ValueError(f"Invalid image: {image}")

     return image, image_size


+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+
+        return vr
+
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
```
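`encode_video` (which eagerly decoded frames to PIL images) is replaced by `load_video`, which returns a decord `VideoReader` and leaves frame sampling to the caller. A hypothetical usage sketch — the file path is a placeholder, and frame access follows decord's API:

```python
from sglang.srt.utils import load_video

vr = load_video("clip.mp4")  # also accepts raw bytes, URLs, and data:/base64 strings
fps = vr.get_avg_fps()
# Sample roughly one frame per second.
indices = list(range(0, len(vr), max(int(round(fps)), 1)))
frames = vr.get_batch(indices).asnumpy()  # (num_frames, H, W, 3) uint8 array
```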
```diff
@@ -923,18 +945,41 @@ class CustomCacheManager(FileCacheManager):

         self.key = key
         self.lock_path = None
+
+        try:
+            module_path = "triton.runtime.cache"
+            cache_module = importlib.import_module(module_path)
+
+            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
+            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
+            default_override_dir = getattr(cache_module, "default_override_dir", None)
+        except (ModuleNotFoundError, AttributeError) as e:
+            default_cache_dir = None
+            default_dump_dir = None
+            default_override_dir = None
+
         if dump:
-            self.cache_dir = default_dump_dir()
+            self.cache_dir = (
+                default_dump_dir()
+                if default_dump_dir is not None
+                else os.path.join(Path.home(), ".triton", "dump")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir = default_override_dir()
+            self.cache_dir = (
+                default_override_dir()
+                if default_override_dir is not None
+                else os.path.join(Path.home(), ".triton", "override")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = (
-                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
+                default_cache_dir()
+                if default_cache_dir is not None
+                else os.path.join(Path.home(), ".triton", "cache")
             )
             if self.cache_dir:
                 try:
```
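Newer Triton releases removed the module-level `default_cache_dir`/`default_dump_dir`/`default_override_dir` helpers, so `CustomCacheManager` now resolves them via `getattr` at runtime and falls back to the conventional `~/.triton/...` locations. A standalone sketch of the same fallback pattern (the function name here is illustrative):

```python
import importlib
import os
from pathlib import Path


def resolve_triton_cache_dir() -> str:
    """Prefer TRITON_CACHE_DIR, then Triton's helper if it still exists,
    then ~/.triton/cache."""
    try:
        cache_module = importlib.import_module("triton.runtime.cache")
        default_cache_dir = getattr(cache_module, "default_cache_dir", None)
    except ModuleNotFoundError:
        default_cache_dir = None
    return os.getenv("TRITON_CACHE_DIR", "").strip() or (
        default_cache_dir()
        if default_cache_dir is not None
        else os.path.join(Path.home(), ".triton", "cache")
    )
```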
```diff
@@ -1848,7 +1893,7 @@ class MultiprocessingSerializer:

         if output_str:
             # Convert bytes to base64-encoded string
-            output = base64.b64encode(output).decode("utf-8")
+            output = pybase64.b64encode(output).decode("utf-8")

         return output

@@ -1865,7 +1910,7 @@ class MultiprocessingSerializer:
         """
         if isinstance(data, str):
             # Decode base64 string to bytes
-            data = base64.b64decode(data)
+            data = pybase64.b64decode(data, validate=True)

         return ForkingPickler.loads(data)

```
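Both directions of `MultiprocessingSerializer` now go through `pybase64`, with strict validation on decode. A hypothetical round trip, assuming the `output_str=True` keyword shown in the surrounding code:

```python
from sglang.srt.utils import MultiprocessingSerializer

payload = {"rank": 0, "ok": True}
wire = MultiprocessingSerializer.serialize(payload, output_str=True)
assert isinstance(wire, str)  # base64 text, safe to embed in JSON/ZMQ messages
assert MultiprocessingSerializer.deserialize(wire) == payload
```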
```diff
@@ -2737,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
         return wrapper

     return decorator
+
+
+def apply_module_patch(target_module, target_function, wrappers):
+    original_module, original_function = parse_module_path(
+        target_module, target_function, False
+    )
+
+    original_function_id = id(original_function)
+
+    candidate = original_function
+    for wrapper in wrappers:
+        candidate = wrapper(candidate)
+    if target_function is not None:
+        setattr(original_module, target_function, candidate)
+
+    for key, value in sys.modules.copy().items():
+        if (
+            target_function is not None
+            and hasattr(value, target_function)
+            and id(getattr(value, target_function)) == original_function_id
+        ):
+            setattr(value, target_function, candidate)
+
+
+def parse_module_path(module_path, function_name, create_dummy):
+    from importlib.machinery import ModuleSpec
+
+    def create_dummy_module(full_path, parent=None):
+        """Create and register a placeholder module"""
+        dummy = types.ModuleType(full_path)
+        dummy.__file__ = "vllm_ascend.dummy_module.py"
+        dummy.__spec__ = ModuleSpec(full_path, None)
+        sys.modules[full_path] = dummy
+        if parent:
+            setattr(parent, full_path.split(".")[-1], dummy)
+        return dummy
+
+    def create_placeholder_function(func_name):
+        """Create dummy function that raises when called"""
+
+        def placeholder(*args, **kwargs):
+            raise NotImplementedError(f"Function {func_name} is a placeholder")
+
+        placeholder.__name__ = func_name
+        return placeholder
+
+    modules = module_path.split(".")
+    current_module = None
+    processed_path = []
+
+    for idx, part in enumerate(modules):
+        current_path = ".".join(modules[: idx + 1])
+        parent_path = ".".join(modules[:idx]) if idx > 0 else None
+
+        try:
+            current_module = importlib.import_module(current_path)
+        except ModuleNotFoundError:
+            # Handle missing module
+            parent = importlib.import_module(parent_path) if parent_path else None
+            if parent and hasattr(parent, part):
+                # Use existing attribute from parent
+                current_module = getattr(parent, part)
+                # Check for early function resolution
+                if function_name and hasattr(current_module, function_name):
+                    return current_module, getattr(current_module, function_name)
+                if function_name and create_dummy:
+                    ph_func = create_placeholder_function(function_name)
+                    setattr(current_module, function_name, ph_func)
+                    return current_module, ph_func
+                if function_name:
+                    raise AttributeError(
+                        f"Function {function_name} missing in {current_path}"
+                    )
+            else:
+                if not create_dummy:
+                    raise
+                # Create and register dummy module
+                current_module = create_dummy_module(
+                    current_path,
+                    parent=(
+                        importlib.import_module(parent_path) if parent_path else None
+                    ),
+                )
+
+        processed_path.append(part)
+
+    # Final function handling
+    final_module = sys.modules[module_path]
+    if function_name is not None:
+        if not hasattr(final_module, function_name):
+            if create_dummy:
+                ph_func = create_placeholder_function(function_name)
+                setattr(final_module, function_name, ph_func)
+            else:
+                setattr(final_module, function_name, None)
+        return final_module, getattr(final_module, function_name)
+
+    return final_module, None
```
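`apply_module_patch` resolves a function through `parse_module_path`, wraps it with the given decorators, and rebinds every module attribute in `sys.modules` that still points at the original object, so `from x import f`-style aliases get patched too. A hypothetical usage sketch:

```python
import functools

from sglang.srt.utils import apply_module_patch


def log_calls(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper


apply_module_patch("json", "dumps", [log_calls])

import json

json.dumps({"patched": True})  # prints "calling dumps" before serializing
```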
sglang/test/test_cutlass_w4a8_moe.py
ADDED
```diff
@@ -0,0 +1,281 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import pytest
+import torch
+
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
+    if int4_values_interleaved.shape[-1] % 2 != 0:
+        raise ValueError(
+            "the last dim size of int4_values_interleaved tensor must be even."
+        )
+
+    input_tensor_int8 = int4_values_interleaved.to(torch.int8)
+
+    low_nibbles = input_tensor_int8[..., 0::2]
+    high_nibbles = input_tensor_int8[..., 1::2]
+
+    packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
+
+    return packed_tensor.to(torch.int8)
+
+
+def pack_interleave(num_experts, ref_weight, ref_scale):
+    n, k = ref_weight.shape[1], ref_weight.shape[2]
+
+    weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
+    w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
+    w_q = w_q.contiguous()
+
+    scale_interleaved = ref_scale.reshape(
+        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+    )  # [E, N, K/4, 4]
+    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
+    scale_interleaved = scale_interleaved.reshape(
+        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+    )  # [E, K/4, N*4]
+    w_scale = scale_interleaved.contiguous()
+
+    return w_q, w_scale
+
+
+@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("N", [2048])
+@pytest.mark.parametrize("K", [7168])
+@pytest.mark.parametrize("E", [256])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("topk", [8])
+@pytest.mark.parametrize("group_size", [128])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
+    local_e = E // ep_size
+
+    debug = False
+    if debug:
+        a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
+        ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
+        ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
+        a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        scale_1 = torch.ones(
+            (local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
+        )
+        scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
+    else:
+        a = torch.randn(M, K, dtype=dtype, device="cuda")
+        ref_weight_1 = torch.randint(
+            -8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
+        )
+        ref_weight_2 = torch.randint(
+            -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
+        )
+        affine_coeff = 0.005
+        a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        scale_1 = (
+            torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+        scale_2 = (
+            torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+
+    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
+    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+
+    device = "cuda"
+    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
+    a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
+    c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    b_strides1 = a_strides1
+    s_strides13 = c_strides1
+    b_strides2 = a_strides2
+    s_strides2 = c_strides2
+
+    score = torch.randn((M, E), dtype=dtype, device=device)
+    topk_weights, topk_ids = select_experts(
+        hidden_states=a,
+        router_logits=score,
+        top_k=topk,
+        use_grouped_topk=False,
+        renormalize=False,
+    )
+    expert_map = torch.arange(E, dtype=torch.int32, device=device)
+    expert_map[local_e:] = E
+
+    output = cutlass_moe(
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        0,
+        local_e - 1,
+        E,
+        a1_scale,
+        a2_scale,
+        expert_map,
+    )
+
+    ref_output = ref(
+        a,
+        local_e,
+        topk_weights,
+        topk_ids,
+        ref_weight_1,
+        ref_weight_2,
+        scale_1,
+        scale_2,
+        has_pre_quant=True,
+        has_alpha=True,
+        pre_quant_scale_1=a1_scale,
+        pre_quant_scale_2=a2_scale,
+        alpha_1=a1_scale,
+        alpha_2=a2_scale,
+    )
+
+    # compare
+    torch.cuda.synchronize()
+
+    # compare final output
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    print("SUCCESS: Final output tensors are close.")
+
+
+def cutlass_moe(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    start_expert_id: int,
+    end_expert_id: int,
+    E: int,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+):
+    local_topk_ids = topk_ids_
+    local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
+    device = a.device
+
+    local_num_experts = end_expert_id - start_expert_id + 1
+    expert_offsets = torch.empty(
+        (local_num_experts + 1), dtype=torch.int32, device=device
+    )
+    problem_sizes1 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    problem_sizes2 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    return cutlass_w4a8_moe(
+        start_expert_id,
+        end_expert_id,
+        E,
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids_,
+        local_topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a1_scale,
+        a2_scale,
+        apply_router_weight_on_input,
+    )
+
+
+def ref(
+    x: torch.Tensor,
+    num_experts: int,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    ref_weight_1: torch.Tensor,
+    ref_weight_2: torch.Tensor,
+    ref_weight_scale_1: torch.Tensor,
+    ref_weight_scale_2: torch.Tensor,
+    has_pre_quant: bool = False,
+    has_alpha: bool = False,
+    pre_quant_scale_1: Optional[torch.Tensor] = None,
+    pre_quant_scale_2: Optional[torch.Tensor] = None,
+    alpha_1: Optional[torch.Tensor] = None,
+    alpha_2: Optional[torch.Tensor] = None,
+):
+    results = torch.zeros_like(x)
+    dtype = x.dtype
+    for e_idx in range(num_experts):
+        mask = topk_ids == e_idx
+        activated_tokens = mask.sum(1).bool()
+        act = x[activated_tokens, :]
+        if act.shape[0] == 0:
+            continue
+        final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
+
+        act = (
+            torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
+            .to(torch.float8_e4m3fn)
+            .to(dtype)
+        )
+        w3_w1 = ref_weight_1[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
+        fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
+
+        gate, fc1 = fc1.chunk(2, dim=-1)
+        fc1 = fc1 * torch.nn.functional.silu(gate)
+        act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+        act = act.to(dtype)
+
+        w2 = ref_weight_2[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
+        fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
+
+        results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
+
+    return results
```
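As a quick sanity check of `pack_int4_values_to_int8` above: each adjacent pair of int4 values collapses into one int8 byte, low nibble first. A worked example with illustrative values:

```python
import torch

# Pair (3, -2): low nibble 3 = 0b0011; high nibble -2 = 0b1110 (two's complement).
# Packed byte 0b1110_0011 = -29 as signed int8.
vals = torch.tensor([[3, -2]], dtype=torch.int8)
low, high = vals[..., 0::2], vals[..., 1::2]
packed = (high << 4) | (low & 0x0F)
assert packed.item() == -29
```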
sglang/utils.py
CHANGED
```diff
@@ -1,6 +1,5 @@
 """Common utilities"""

-import base64
 import importlib
 import json
 import logging
@@ -20,6 +19,7 @@ from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import numpy as np
+import pybase64
 import requests
 from IPython.display import HTML, display
 from pydantic import BaseModel
@@ -148,15 +148,15 @@ def encode_image_base64(image_path: Union[str, bytes]):
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
             data = image_file.read()
-        return base64.b64encode(data).decode("utf-8")
+        return pybase64.b64encode(data).decode("utf-8")
     elif isinstance(image_path, bytes):
-        return base64.b64encode(image_path).decode("utf-8")
+        return pybase64.b64encode(image_path).decode("utf-8")
     else:
         # image_path is PIL.WebPImagePlugin.WebPImageFile
         image = image_path
         buffered = BytesIO()
         image.save(buffered, format="PNG")
-        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+        return pybase64.b64encode(buffered.getvalue()).decode("utf-8")


 def encode_frame(frame):
```
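`encode_image_base64` keeps its three input forms (path, raw bytes, PIL image); only the encoder changed. A sketch, with a placeholder path:

```python
from sglang.utils import encode_image_base64

with open("example.png", "rb") as f:  # placeholder path
    raw = f.read()

# Path and bytes inputs produce the same base64 text.
assert encode_image_base64("example.png") == encode_image_base64(raw)
```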
```diff
@@ -223,7 +223,7 @@ def encode_video_base64(video_path: str, num_frames: int = 16):
     video_bytes = b"".join(encoded_frames)

     # Encode the concatenated bytes to base64
-    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+    video_base64 = "video:" + pybase64.b64encode(video_bytes).decode("utf-8")

     return video_base64

```
sglang/version.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = "0.4.9"
+__version__ = "0.4.9.post2"
```