sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import builtins
 import ctypes
 import dataclasses
 import io
@@ -37,6 +38,7 @@ import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
+from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
@@ -52,11 +54,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from packaging.version import Version, parse
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -69,14 +73,31 @@ logger = logging.getLogger(__name__)
 show_time_cost = False
 time_infos = {}

+HIP_FP8_E4M3_FNUZ_MAX = 224.0

+
+# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
-    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None


+if is_hip():
+    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
+else:
+    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+FP8_E4M3_MIN = -FP8_E4M3_MAX
+
+builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
+builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
+
+
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
+
+
 def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda


 def is_cuda_alike():
@@ -98,11 +119,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def enable_show_time_cost():
@@ -1045,6 +1066,65 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)


+@lru_cache(maxsize=1)
+def is_habana_available() -> bool:
+    return find_spec("habana_frameworks") is not None
+
+
+@lru_cache(maxsize=8)
+def get_device(device_id: Optional[int] = None) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        if device_id is None:
+            return "cuda"
+        return "cuda:{}".format(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        if device_id == None:
+            return "xpu"
+        return "xpu:{}".format(device_id)
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                if device_id == None:
+                    return "hpu"
+                return "hpu:{}".format(device_id)
+        except ImportError as e:
+            raise ImportError(
+                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
+            )
+
+    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
+
+
+@lru_cache(maxsize=1)
+def get_device_count() -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        try:
+            return torch.cuda.device_count()
+        except RuntimeError:
+            return 0
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        try:
+            return torch.xpu.device_count()
+        except RuntimeError:
+            return 0
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                return torch.hpu.device_count()
+        except (ImportError, RuntimeError):
+            return 0
+
+    return 0  # No accelerators available
+
+
 def get_device_core_count(device_id: int = 0) -> int:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_properties(device_id).multi_processor_count
@@ -1063,11 +1143,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         )
         major, minor = int(major), int(minor)

-    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
-    # Update this once the support is available.
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         try:
-            major, minor = torch.hpu.get_device_capability(device_id)
+            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+            # Update this once the support is available.
+            # major, minor = torch.hpu.get_device_capability(device_id)
+            major, minor = None, None
         except Exception as e:
             raise RuntimeError(
                 f"An error occurred while getting device capability of hpu: {e}."
@@ -1269,7 +1350,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)
     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
@@ -1341,7 +1423,7 @@ def kill_itself_when_parent_died():
         libc = ctypes.CDLL("libc.so.6")
         libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
     else:
-        logger.warninig("kill_itself_when_parent_died is only supported in linux.")
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")


 def set_uvicorn_logging_configs():
@@ -1430,6 +1512,12 @@ def rank0_print(msg: str):
         print(msg, flush=True)


+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
 def launch_dummy_health_check_server(host, port):
     import uvicorn
     from fastapi import FastAPI, Response
@@ -1466,6 +1554,13 @@ def set_cuda_arch():
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


+def next_power_of_2(n: int):
+    return 1 << (n - 1).bit_length() if n > 0 else 1
+
+
+setattr(triton, "next_power_of_2", next_power_of_2)
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.

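A minimal usage sketch (not part of the published diff) of the accelerator helpers this release adds to sglang/srt/utils.py above; it assumes sglang 0.4.4 is importable and the printed values are only illustrative.

# Sketch only: exercises get_device, get_device_count, and get_cuda_version
# as introduced in the diff above. Assumes sglang 0.4.4 is installed.
from sglang.srt.utils import get_cuda_version, get_device, get_device_count

if get_device_count() > 0:
    # get_device() returns "cuda", "xpu", or "hpu" (optionally with an index);
    # it raises RuntimeError when no accelerator is present.
    print(get_device(0), get_cuda_version())
else:
    print("no accelerator (CUDA, XPU, HPU) detected")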
sglang/test/runners.py CHANGED
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor

 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
@@ -135,6 +135,76 @@ class HFRunner:
             return True
         return False

+    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
+
+    def _get_gme_qwen2_vl_embeddings(
+        self, prompts, image_data: Optional[List[str]] = None
+    ):
+        from sglang.srt.utils import load_image
+
+        images = None
+        if image_data is not None:
+            images = [load_image(image)[0] for image in image_data]
+
+        inputs = self.processor(
+            text=prompts,
+            images=images,
+            padding=True,
+            truncation=True,
+            max_length=1800,
+            return_tensors="pt",
+        )
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            embeddings = self._forward_gme_qwen2_vl(**inputs)
+        return embeddings.tolist()
+
+    def _forward_gme_qwen2_vl(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        pooling_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.model.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.model.visual.get_dtype())
+                image_embeds = self.model.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).to(inputs_embeds.device)
+                image_mask = input_ids == self.model.config.image_token_id
+                inputs_embeds[image_mask] = image_embeds
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.model.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
+        if left_padding:
+            embeddings = outputs.last_hidden_state[:, -1]
+        else:
+            sequence_lengths = pooling_mask.sum(dim=1) - 1
+            batch_size = outputs.last_hidden_state.shape[0]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.contiguous()
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
@@ -148,9 +218,18 @@
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            self.model = _get_sentence_transformer_embedding_model(
-                model_path, torch_dtype
-            )
+            if "gme-qwen2-vl" in model_path.lower():
+                self.model = AutoModelForVision2Seq.from_pretrained(
+                    model_path,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=False,
+                    low_cpu_mem_usage=True,
+                ).cuda()
+                self.processor = AutoProcessor.from_pretrained(model_path)
+            else:
+                self.model = _get_sentence_transformer_embedding_model(
+                    model_path, torch_dtype
+                )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification

@@ -169,7 +248,9 @@

         # Run forward
         while True:
-            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
+            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
+                in_queue.get()
+            )
             if lora_paths is not None:
                 assert len(prompts) == len(lora_paths)

@@ -189,7 +270,10 @@
                    )
                elif self.model_type == "embedding":
                    assert not self.output_str_only
-                    logits = self.model.encode(prompts).tolist()
+                    if "gme-qwen2-vl" in model_path.lower():
+                        logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+                    else:
+                        logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))

                elif self.model_type == "reward":
@@ -211,11 +295,14 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         token_ids_logprob: Optional[int] = None,
     ):
-        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
+        self.in_queue.put(
+            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
+        )
         return self.out_queue.get()

     def terminate(self):
@@ -396,6 +483,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         logprob_start_len: int = 0,
@@ -413,17 +501,23 @@
                 token_ids_logprob=token_ids_logprob,
             )
         else:
-            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
-                logits = [x["embedding"] for x in response]
+                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                if isinstance(response, list):
+                    logits = [x["embedding"] for x in response]
+                else:
+                    logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # reward model
             else:
+                response = self.engine.encode(prompts)
                 scores = [x["embedding"][0] for x in response]
                 return ModelOutput(scores=scores)

     def batch_forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens=8,
         lora_paths=None,
     ):
@@ -439,7 +533,7 @@
                 lora_paths=lora_paths,
             )
         else:
-            response = self.engine.encode(prompts)
+            response = self.engine.encode(prompts, image_data)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
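With the changes above, HFRunner.forward and SRTRunner.forward accept an optional image_data argument that feeds the new gme-Qwen2-VL embedding path. A hedged sketch follows (not part of the published diff); the model path and image URL are placeholders, and the constructor arguments shown are indicative rather than exhaustive.

# Sketch only: embedding a text prompt plus an image through the updated runner API.
# "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct" matches the "gme-qwen2-vl" branch added above;
# the image URL is a placeholder.
import torch

from sglang.test.runners import HFRunner

runner = HFRunner(
    "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16,
    model_type="embedding",
)
out = runner.forward(
    prompts=["a photo of a cat"],
    image_data=["https://example.com/cat.jpg"],  # placeholder URL
)
print(len(out.embed_logits[0]))  # embedding width
runner.terminate()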
sglang/test/test_block_fp8.py CHANGED
@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest

 import torch
@@ -7,9 +8,12 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )

+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+

 # For test
 def native_per_token_group_quant_fp8(
@@ -63,7 +67,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
         out, scale = per_token_group_quant_fp8(x, group_size)

         self.assertTrue(
-            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
         )
         self.assertTrue(torch.allclose(scale, ref_scale))

@@ -85,6 +89,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
                 self._per_token_group_quant_fp8(*params)


+# For test
+def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
+    """Function to perform static quantization on an input tensor `x` using native torch.
+
+    It converts the tensor values into float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
+    x_s_inv = 1.0 / x_s
+    x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+
+    return x_q, x_s
+
+
+class TestStaticQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+        x_s = x.max() / fp8_max
+
+        with torch.inference_mode():
+            ref_out, _ = native_static_quant_fp8(x, x_s)
+            out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+
+    def test_static_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._static_quant_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
@@ -142,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl


 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]

     @classmethod
     def setUpClass(cls):
@@ -156,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")

-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -191,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)

sglang/test/test_custom_ops.py ADDED
@@ -0,0 +1,88 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
+
+import pytest
+import torch
+
+from sglang.srt.custom_op import scaled_fp8_quant
+from sglang.srt.utils import is_cuda
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant_per_tensor(dtype) -> None:
+
+    def quantize_ref_per_tensor(tensor, inv_scale):
+        # The reference implementation that fully aligns to
+        # the kernel being tested.
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        scale = inv_scale.reciprocal()
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+        qweight = qweight.to(torch.float8_e4m3fn)
+        return qweight
+
+    def dequantize_per_tensor(tensor, inv_scale, dtype):
+        fake_qweight = tensor.to(dtype)
+        dq_weight = fake_qweight * inv_scale
+        return dq_weight
+
+    # Note that we use a shape % 8 != 0 to cover edge cases,
+    # because scaled_fp8_quant is vectorized by 8.
+    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
+
+    # Test Per Tensor Dynamic quantization
+    # scale = max(abs(x)) / FP8_E4M3_MAX
+    y, scale = scaled_fp8_quant(x, None)
+    ref_y = quantize_ref_per_tensor(x, scale)
+    torch.testing.assert_close(y, ref_y)
+    torch.testing.assert_close(
+        dequantize_per_tensor(y, scale, dtype),
+        dequantize_per_tensor(ref_y, scale, dtype),
+    )
+
+    # Test Per Tensor Static quantization
+    y, _ = scaled_fp8_quant(x, scale)
+    ref_y = quantize_ref_per_tensor(x, scale)
+    torch.testing.assert_close(y, ref_y)
+    torch.testing.assert_close(
+        dequantize_per_tensor(y, scale, dtype),
+        dequantize_per_tensor(ref_y, scale, dtype),
+    )
+
+
+if is_cuda:
+
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
+        def quantize_ref_per_token(tensor, inv_scale):
+            # The reference implementation that fully aligns to
+            # the kernel being tested.
+            finfo = torch.finfo(torch.float8_e4m3fn)
+            scale = inv_scale.reciprocal()
+            qweight = (tensor.to(torch.float32) * scale).clamp(
+                min=finfo.min, max=finfo.max
+            )
+            qweight = qweight.to(torch.float8_e4m3fn)
+            return qweight
+
+        def dequantize_per_token(tensor, inv_scale, dtype):
+            fake_qweight = tensor.to(dtype)
+            dq_weight = fake_qweight * inv_scale
+            return dq_weight
+
+        # Note that we use a shape % 8 = 0,
+        # because per_token_quant_fp8 is vectorized by 8 elements.
+        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)
+
+        # Test Per Tensor Dynamic quantization
+        # scale = max(abs(x)) / FP8_E4M3_MAX
+        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
+        ref_y = quantize_ref_per_token(x, scale)
+        torch.testing.assert_close(y, ref_y)
+        torch.testing.assert_close(
+            dequantize_per_token(y, scale, dtype),
+            dequantize_per_token(ref_y, scale, dtype),
+        )
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
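The two new FP8 test files above check the same per-tensor recipe: divide by a single scale, clamp to the float8_e4m3fn range, cast, then multiply the scale back in to dequantize. Below is a self-contained sketch of that round trip in plain PyTorch (not part of the diff, and not using the sglang kernels); all names are local to this example.

# Sketch only: per-tensor FP8 (e4m3fn) quantize/dequantize round trip,
# mirroring the reference implementations in the tests above.
import torch

x = torch.randn(11, 16, dtype=torch.float32)
fp8 = torch.finfo(torch.float8_e4m3fn)            # max is 448.0 for e4m3fn
scale = x.abs().max() / fp8.max                   # dynamic per-tensor scale
x_q = (x / scale).clamp(min=fp8.min, max=fp8.max).to(torch.float8_e4m3fn)
x_dq = x_q.to(torch.float32) * scale              # dequantize
print((x - x_dq).abs().max())                     # quantization error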