sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -15,7 +15,6 @@
 
 from __future__ import annotations
 
-import base64
 import builtins
 import ctypes
 import dataclasses
@@ -68,6 +67,7 @@ from typing import (
 
 import numpy as np
 import psutil
+import pybase64
 import requests
 import torch
 import torch.distributed
@@ -83,12 +83,7 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+from triton.runtime.cache import FileCacheManager
 
 logger = logging.getLogger(__name__)
 
@@ -202,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:
 
 
 def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
+    return backend not in ["torch_native", "intel_amx", "ascend"]
 
 
 try:
@@ -621,7 +616,7 @@ def decode_video_base64(video_base64):
     from PIL import Image
 
     # Decode the base64 string
-    video_bytes = base64.b64decode(video_base64)
+    video_bytes = pybase64.b64decode(video_base64, validate=True)
 
     # Placeholder for the start indices of each PNG image
     img_starts = []
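A note on the `base64` → `pybase64` migration in this file: `pybase64` is a C-accelerated drop-in replacement for the stdlib `base64` module, and `validate=True` makes decoding strict, raising `binascii.Error` on characters outside the base64 alphabet instead of silently discarding them. A minimal standalone sketch of the behavioral difference (not part of this diff):

    import base64
    import binascii

    import pybase64

    payload = pybase64.b64encode(b"frame-bytes").decode("utf-8")
    assert pybase64.b64decode(payload, validate=True) == b"frame-bytes"

    # The stdlib default silently drops non-alphabet characters ...
    assert base64.b64decode("aGk=\n!!") == b"hi"
    # ... while strict validation rejects the same input.
    try:
        pybase64.b64decode("aGk=\n!!", validate=True)
    except binascii.Error:
        print("strict decoding rejected malformed input")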
@@ -707,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
-        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+        audio, original_sr = sf.read(
+            BytesIO(pybase64.b64decode(audio_file, validate=True))
+        )
     elif audio_file.startswith("http://") or audio_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
         response = requests.get(audio_file, stream=True, timeout=timeout)
@@ -731,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio
 
 
-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
@@ -776,18 +746,70 @@ def load_image(
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     elif isinstance(image_file, str):
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
         raise ValueError(f"Invalid image: {image}")
 
     return image, image_size
 
 
+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+
+        return vr
+
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -923,18 +945,41 @@ class CustomCacheManager(FileCacheManager):
 
         self.key = key
         self.lock_path = None
+
+        try:
+            module_path = "triton.runtime.cache"
+            cache_module = importlib.import_module(module_path)
+
+            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
+            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
+            default_override_dir = getattr(cache_module, "default_override_dir", None)
+        except (ModuleNotFoundError, AttributeError) as e:
+            default_cache_dir = None
+            default_dump_dir = None
+            default_override_dir = None
+
         if dump:
-            self.cache_dir = default_dump_dir()
+            self.cache_dir = (
+                default_dump_dir()
+                if default_dump_dir is not None
+                else os.path.join(Path.home(), ".triton", "dump")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir = default_override_dir()
+            self.cache_dir = (
+                default_override_dir()
+                if default_override_dir is not None
+                else os.path.join(Path.home(), ".triton", "override")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = (
-                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
+                default_cache_dir()
+                if default_cache_dir is not None
+                else os.path.join(Path.home(), ".triton", "cache")
             )
         if self.cache_dir:
             try:
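The rewritten constructor no longer assumes `triton.runtime.cache` exports `default_cache_dir`, `default_dump_dir`, or `default_override_dir`; it looks each one up with `getattr(..., None)` and falls back to the historical `~/.triton/...` paths when the helper is missing. The effective cache-dir resolution order reduces to this standalone sketch (the dump/override branches and the surrounding try/except are omitted here):

    import importlib
    import os
    from pathlib import Path

    def resolve_triton_cache_dir() -> str:
        # 1) explicit env var, 2) Triton's own helper if it still exists,
        # 3) the historical ~/.triton/cache fallback.
        env_dir = os.getenv("TRITON_CACHE_DIR", "").strip()
        if env_dir:
            return env_dir
        cache_module = importlib.import_module("triton.runtime.cache")
        default_cache_dir = getattr(cache_module, "default_cache_dir", None)
        if default_cache_dir is not None:
            return default_cache_dir()
        return os.path.join(Path.home(), ".triton", "cache")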
@@ -1848,7 +1893,7 @@ class MultiprocessingSerializer:
 
         if output_str:
             # Convert bytes to base64-encoded string
-            output = base64.b64encode(output).decode("utf-8")
+            output = pybase64.b64encode(output).decode("utf-8")
 
         return output
 
@@ -1865,7 +1910,7 @@ class MultiprocessingSerializer:
         """
         if isinstance(data, str):
             # Decode base64 string to bytes
-            data = base64.b64decode(data)
+            data = pybase64.b64decode(data, validate=True)
 
         return ForkingPickler.loads(data)
 
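Both the `serialize` and `deserialize` paths now route base64 through `pybase64`. A standalone sketch of the same pickle-plus-base64 roundtrip these methods perform, using `ForkingPickler` directly as the hunk does (an illustration of the mechanism, not the class API):

    import pybase64
    from multiprocessing.reduction import ForkingPickler

    obj = {"req_id": 7, "tokens": [1, 2, 3]}

    # Serialize: ForkingPickler bytes, optionally wrapped as a base64 string.
    raw = bytes(ForkingPickler.dumps(obj))
    as_str = pybase64.b64encode(raw).decode("utf-8")

    # Deserialize: accept either the string or the raw bytes form.
    data = pybase64.b64decode(as_str, validate=True)
    assert ForkingPickler.loads(data) == obj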
@@ -2737,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
         return wrapper
 
     return decorator
+
+
+def apply_module_patch(target_module, target_function, wrappers):
+    original_module, original_function = parse_module_path(
+        target_module, target_function, False
+    )
+
+    original_function_id = id(original_function)
+
+    candidate = original_function
+    for wrapper in wrappers:
+        candidate = wrapper(candidate)
+    if target_function is not None:
+        setattr(original_module, target_function, candidate)
+
+    for key, value in sys.modules.copy().items():
+        if (
+            target_function is not None
+            and hasattr(value, target_function)
+            and id(getattr(value, target_function)) == original_function_id
+        ):
+            setattr(value, target_function, candidate)
+
+
+def parse_module_path(module_path, function_name, create_dummy):
+    from importlib.machinery import ModuleSpec
+
+    def create_dummy_module(full_path, parent=None):
+        """Create and register a placeholder module"""
+        dummy = types.ModuleType(full_path)
+        dummy.__file__ = "vllm_ascend.dummy_module.py"
+        dummy.__spec__ = ModuleSpec(full_path, None)
+        sys.modules[full_path] = dummy
+        if parent:
+            setattr(parent, full_path.split(".")[-1], dummy)
+        return dummy
+
+    def create_placeholder_function(func_name):
+        """Create dummy function that raises when called"""
+
+        def placeholder(*args, **kwargs):
+            raise NotImplementedError(f"Function {func_name} is a placeholder")
+
+        placeholder.__name__ = func_name
+        return placeholder
+
+    modules = module_path.split(".")
+    current_module = None
+    processed_path = []
+
+    for idx, part in enumerate(modules):
+        current_path = ".".join(modules[: idx + 1])
+        parent_path = ".".join(modules[:idx]) if idx > 0 else None
+
+        try:
+            current_module = importlib.import_module(current_path)
+        except ModuleNotFoundError:
+            # Handle missing module
+            parent = importlib.import_module(parent_path) if parent_path else None
+            if parent and hasattr(parent, part):
+                # Use existing attribute from parent
+                current_module = getattr(parent, part)
+                # Check for early function resolution
+                if function_name and hasattr(current_module, function_name):
+                    return current_module, getattr(current_module, function_name)
+                if function_name and create_dummy:
+                    ph_func = create_placeholder_function(function_name)
+                    setattr(current_module, function_name, ph_func)
+                    return current_module, ph_func
+                if function_name:
+                    raise AttributeError(
+                        f"Function {function_name} missing in {current_path}"
+                    )
+            else:
+                if not create_dummy:
+                    raise
+                # Create and register dummy module
+                current_module = create_dummy_module(
+                    current_path,
+                    parent=(
+                        importlib.import_module(parent_path) if parent_path else None
+                    ),
+                )
+
+        processed_path.append(part)
+
+    # Final function handling
+    final_module = sys.modules[module_path]
+    if function_name is not None:
+        if not hasattr(final_module, function_name):
+            if create_dummy:
+                ph_func = create_placeholder_function(function_name)
+                setattr(final_module, function_name, ph_func)
+            else:
+                setattr(final_module, function_name, None)
+        return final_module, getattr(final_module, function_name)
+
+    return final_module, None
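`apply_module_patch` resolves a dotted module path, wraps the named function with each wrapper in order, and then rebinds every alias of the original function found in `sys.modules`, so `from x import f` style imports get patched as well. A hedged usage sketch against a stdlib function (the `log_calls` wrapper is illustrative):

    import functools
    import json

    from sglang.srt.utils import apply_module_patch

    def log_calls(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"calling {fn.__name__}")
            return fn(*args, **kwargs)

        return wrapper

    # Rebinds json.dumps (and any module-level alias of it) to the wrapped version.
    apply_module_patch("json", "dumps", [log_calls])
    json.dumps({"a": 1})  # prints "calling dumps", then returns '{"a": 1}'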
sglang/test/test_cutlass_w4a8_moe.py ADDED
@@ -0,0 +1,281 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import pytest
+import torch
+
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
+    if int4_values_interleaved.shape[-1] % 2 != 0:
+        raise ValueError(
+            "the last dim size of int4_values_interleaved tensor must be even."
+        )
+
+    input_tensor_int8 = int4_values_interleaved.to(torch.int8)
+
+    low_nibbles = input_tensor_int8[..., 0::2]
+    high_nibbles = input_tensor_int8[..., 1::2]
+
+    packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
+
+    return packed_tensor.to(torch.int8)
+
+
+def pack_interleave(num_experts, ref_weight, ref_scale):
+    n, k = ref_weight.shape[1], ref_weight.shape[2]
+
+    weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
+    w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
+    w_q = w_q.contiguous()
+
+    scale_interleaved = ref_scale.reshape(
+        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+    )  # [E, N, K/4, 4]
+    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
+    scale_interleaved = scale_interleaved.reshape(
+        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+    )  # [E, K/4, N*4]
+    w_scale = scale_interleaved.contiguous()
+
+    return w_q, w_scale
+
+
+@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("N", [2048])
+@pytest.mark.parametrize("K", [7168])
+@pytest.mark.parametrize("E", [256])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("topk", [8])
+@pytest.mark.parametrize("group_size", [128])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
+    local_e = E // ep_size
+
+    debug = False
+    if debug:
+        a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
+        ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
+        ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
+        a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
+        scale_1 = torch.ones(
+            (local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
+        )
+        scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
+    else:
+        a = torch.randn(M, K, dtype=dtype, device="cuda")
+        ref_weight_1 = torch.randint(
+            -8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
+        )
+        ref_weight_2 = torch.randint(
+            -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
+        )
+        affine_coeff = 0.005
+        a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+        scale_1 = (
+            torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+        scale_2 = (
+            torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
+            * affine_coeff
+        )
+
+    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
+    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+
+    device = "cuda"
+    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
+    a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
+    c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
+    b_strides1 = a_strides1
+    s_strides13 = c_strides1
+    b_strides2 = a_strides2
+    s_strides2 = c_strides2
+
+    score = torch.randn((M, E), dtype=dtype, device=device)
+    topk_weights, topk_ids = select_experts(
+        hidden_states=a,
+        router_logits=score,
+        top_k=topk,
+        use_grouped_topk=False,
+        renormalize=False,
+    )
+    expert_map = torch.arange(E, dtype=torch.int32, device=device)
+    expert_map[local_e:] = E
+
+    output = cutlass_moe(
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        0,
+        local_e - 1,
+        E,
+        a1_scale,
+        a2_scale,
+        expert_map,
+    )
+
+    ref_output = ref(
+        a,
+        local_e,
+        topk_weights,
+        topk_ids,
+        ref_weight_1,
+        ref_weight_2,
+        scale_1,
+        scale_2,
+        has_pre_quant=True,
+        has_alpha=True,
+        pre_quant_scale_1=a1_scale,
+        pre_quant_scale_2=a2_scale,
+        alpha_1=a1_scale,
+        alpha_2=a2_scale,
+    )
+
+    # compare
+    torch.cuda.synchronize()
+
+    # compare final output
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    print("SUCCESS: Final output tensors are close.")
+
+
+def cutlass_moe(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    start_expert_id: int,
+    end_expert_id: int,
+    E: int,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+):
+    local_topk_ids = topk_ids_
+    local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
+    device = a.device
+
+    local_num_experts = end_expert_id - start_expert_id + 1
+    expert_offsets = torch.empty(
+        (local_num_experts + 1), dtype=torch.int32, device=device
+    )
+    problem_sizes1 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    problem_sizes2 = torch.empty(
+        (local_num_experts, 3), dtype=torch.int32, device=device
+    )
+    return cutlass_w4a8_moe(
+        start_expert_id,
+        end_expert_id,
+        E,
+        a,
+        w1_q,
+        w2_q,
+        w1_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids_,
+        local_topk_ids,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides13,
+        s_strides2,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a1_scale,
+        a2_scale,
+        apply_router_weight_on_input,
+    )
+
+
+def ref(
+    x: torch.Tensor,
+    num_experts: int,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    ref_weight_1: torch.Tensor,
+    ref_weight_2: torch.Tensor,
+    ref_weight_scale_1: torch.Tensor,
+    ref_weight_scale_2: torch.Tensor,
+    has_pre_quant: bool = False,
+    has_alpha: bool = False,
+    pre_quant_scale_1: Optional[torch.Tensor] = None,
+    pre_quant_scale_2: Optional[torch.Tensor] = None,
+    alpha_1: Optional[torch.Tensor] = None,
+    alpha_2: Optional[torch.Tensor] = None,
+):
+    results = torch.zeros_like(x)
+    dtype = x.dtype
+    for e_idx in range(num_experts):
+        mask = topk_ids == e_idx
+        activated_tokens = mask.sum(1).bool()
+        act = x[activated_tokens, :]
+        if act.shape[0] == 0:
+            continue
+        final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
+
+        act = (
+            torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
+            .to(torch.float8_e4m3fn)
+            .to(dtype)
+        )
+        w3_w1 = ref_weight_1[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
+        fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
+
+        gate, fc1 = fc1.chunk(2, dim=-1)
+        fc1 = fc1 * torch.nn.functional.silu(gate)
+        act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+        act = act.to(dtype)
+
+        w2 = ref_weight_2[e_idx]
+        ref_w_scale_repeat = (
+            ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
+        )
+        w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
+        fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
+
+        results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
+
+    return results
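For intuition on `pack_int4_values_to_int8` above: each pair of int4 values along the last dimension collapses into one int8, low nibble first. A standalone recomputation on a tiny tensor (no GPU needed; this reimplements the packing line from the test):

    import torch

    vals = torch.tensor([1, 2, -1, 7], dtype=torch.int8)  # int4 range [-8, 7]
    low, high = vals[0::2], vals[1::2]
    packed = (high << 4) | (low & 0x0F)
    # (2 << 4) | 1 = 0x21 = 33; (7 << 4) | (-1 & 0xF) = 0x7F = 127
    print(packed.tolist())  # [33, 127]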
sglang/utils.py CHANGED
@@ -1,6 +1,5 @@
 """Common utilities"""
 
-import base64
 import importlib
 import json
 import logging
@@ -20,6 +19,7 @@ from json import dumps
 from typing import Any, Callable, List, Optional, Tuple, Type, Union
 
 import numpy as np
+import pybase64
 import requests
 from IPython.display import HTML, display
 from pydantic import BaseModel
@@ -148,15 +148,15 @@ def encode_image_base64(image_path: Union[str, bytes]):
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
             data = image_file.read()
-        return base64.b64encode(data).decode("utf-8")
+        return pybase64.b64encode(data).decode("utf-8")
     elif isinstance(image_path, bytes):
-        return base64.b64encode(image_path).decode("utf-8")
+        return pybase64.b64encode(image_path).decode("utf-8")
     else:
         # image_path is PIL.WebPImagePlugin.WebPImageFile
         image = image_path
         buffered = BytesIO()
         image.save(buffered, format="PNG")
-        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+        return pybase64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
 def encode_frame(frame):
@@ -223,7 +223,7 @@ def encode_video_base64(video_path: str, num_frames: int = 16):
     video_bytes = b"".join(encoded_frames)
 
     # Encode the concatenated bytes to base64
-    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+    video_base64 = "video:" + pybase64.b64encode(video_bytes).decode("utf-8")
 
     return video_base64
 
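`encode_image_base64` pairs with the `data:` branch of `load_image` in `sglang/srt/utils.py` above, and since `pybase64` output is byte-for-byte identical to the stdlib's, existing encoded payloads remain compatible. A small roundtrip sketch (in-memory PNG, no files on disk):

    from io import BytesIO

    import pybase64
    from PIL import Image

    img = Image.new("RGB", (8, 8), color="red")
    buf = BytesIO()
    img.save(buf, format="PNG")

    b64 = pybase64.b64encode(buf.getvalue()).decode("utf-8")
    data_url = f"data:image/png;base64,{b64}"

    # Mirrors load_image's "data:" branch: split off the header, then strict-decode.
    payload = data_url.split(",")[1]
    restored = Image.open(BytesIO(pybase64.b64decode(payload, validate=True)))
    assert restored.size == (8, 8)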
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.9"
+__version__ = "0.4.9.post2"