sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
-
 import base64
 import builtins
 import ctypes
@@ -35,8 +34,10 @@ import sys
 import tempfile
 import threading
 import time
+import traceback
 import warnings
 from contextlib import contextmanager
+from enum import Enum
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -53,6 +54,7 @@ import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
+from decord import VideoReader, cpu
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from PIL import Image
@@ -261,7 +263,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus = cuda_device_count_stateless()
+        num_gpus = torch.cuda.device_count()
         assert gpu_id < num_gpus
 
         if torch.cuda.current_device() != gpu_id:
@@ -512,13 +514,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     import soundfile as sf
     from scipy.signal import resample
 
-    # print(f"loading {audio_file}")
     # Load audio data
     if isinstance(audio_file, bytes):
        audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
         audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif audio_file.startswith("http://") or audio_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
+        response = requests.get(audio_file, stream=True, timeout=timeout)
+        audio_file = BytesIO(response.content)
+        response.close()
+        audio, original_sr = sf.read(audio_file)
     elif isinstance(audio_file, str):
         audio, original_sr = sf.read(audio_file)
     else:
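
With the new branch above, load_audio accepts plain HTTP(S) URLs alongside raw bytes, data: URIs, and local paths. A usage sketch (the URL and path below are placeholders, not real files):

    from sglang.srt.utils import load_audio

    # Remote file (new in this release): fetched via requests, with the
    # timeout taken from the REQUEST_TIMEOUT env var (default 5 seconds).
    audio = load_audio("https://example.com/sample.wav", sr=16000)

    # Local paths and raw bytes behave as before.
    audio = load_audio("/tmp/sample.wav", sr=16000)
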
@@ -536,10 +543,38 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio
 
 
-def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
-    image = image_size = None
+def encode_video(video_path, frame_count_limit=None):
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
 
-    if isinstance(image_file, bytes):
+
+def load_image(
+    image_file: Union[Image.Image, str, bytes]
+) -> tuple[Image.Image, tuple[int, int]]:
+    image = image_size = None
+    if isinstance(image_file, Image.Image):
+        image = image_file
+        image_size = (image.width, image.height)
+    elif isinstance(image_file, bytes):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
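
The uniform_sample helper above merits a worked example: it picks n indices evenly spaced through the list, centered within each gap, so the sampled frames cover the whole video instead of clustering at the start. A standalone sketch:

    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    frame_indices = list(range(100))         # stand-in for 100 candidate frames
    print(uniform_sample(frame_indices, 4))  # [12, 37, 62, 87]
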
@@ -563,6 +598,10 @@ def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
 
 
 def suppress_other_loggers():
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
     try:
         from vllm.logger import logger as vllm_default_logger
     except ImportError:
@@ -577,10 +616,6 @@
     )
     logging.getLogger("vllm.config").setLevel(logging.ERROR)
 
-    warnings.filterwarnings(
-        "ignore", category=UserWarning, message="The given NumPy array is not writable"
-    )
-
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -1381,47 +1416,6 @@ def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
 
 
-@lru_cache(maxsize=8)
-def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
-    # Note: cuda_visible_devices is not used, but we keep it as an argument for
-    # LRU Cache purposes.
-
-    # Code below is based on
-    # https://github.com/pytorch/pytorch/blob/
-    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
-    # torch/cuda/__init__.py#L831C1-L831C17
-    import torch.version
-
-    if not torch.cuda._is_compiled():
-        return 0
-    if is_hip():
-        # ROCm uses amdsmi instead of nvml for stateless device count
-        # This requires a sufficiently modern version of Torch 2.4.0
-        raw_count = (
-            torch.cuda._device_count_amdsmi()
-            if (hasattr(torch.cuda, "_device_count_amdsmi"))
-            else -1
-        )
-    else:
-        raw_count = torch.cuda._device_count_nvml()
-    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
-    return r
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
-def cuda_device_count_stateless() -> int:
-    """Get number of CUDA devices, caching based on the value of
-    CUDA_VISIBLE_DEVICES at the time of call.
-
-    This should be used instead of torch.cuda.device_count()
-    unless CUDA_VISIBLE_DEVICES has already been set to the desired
-    value."""
-
-    # This can be removed and simply replaced with torch.cuda.get_device_count
-    # after https://github.com/pytorch/pytorch/pull/122815 is released.
-    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
-
-
 def dataclass_to_string_truncated(
     data, max_length=2048, skip_names: Optional[Set[str]] = None
 ):
@@ -1602,6 +1596,7 @@ def get_ip() -> str:
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
+        port = int(port)
         while True:
             try:
                 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
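
The added int() cast fixes a genuine bug: os.getenv returns a string, so with SGLANG_PORT set, get_open_port previously handed a str to code expecting an integer port. A minimal illustration:

    import socket

    port = "30000"  # what os.getenv("SGLANG_PORT") returns
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", int(port)))  # works; s.bind(("", port)) raises TypeError
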
@@ -1630,6 +1625,38 @@ def is_valid_ipv6_address(address: str) -> bool:
     return False
 
 
+def configure_ipv6(dist_init_addr):
+    addr = dist_init_addr
+    end = addr.find("]")
+    if end == -1:
+        raise ValueError("invalid IPv6 address format: missing ']'")
+
+    host = addr[: end + 1]
+
+    # this only validates the address without brackets: we still need the below checks.
+    # if it's invalid, immediately raise an error so we know it's not formatting issues.
+    if not is_valid_ipv6_address(host[1:end]):
+        raise ValueError(f"invalid IPv6 address: {host}")
+
+    port_str = None
+    if len(addr) > end + 1:
+        if addr[end + 1] == ":":
+            port_str = addr[end + 2 :]
+        else:
+            raise ValueError("received IPv6 address format: expected ':' after ']'")
+
+    if not port_str:
+        raise ValueError(
+            "a port must be specified in IPv6 address (format: [ipv6]:port)"
+        )
+
+    try:
+        port = int(port_str)
+    except ValueError:
+        raise ValueError(f"invalid port in IPv6 address: '{port_str}'")
+    return port, host
+
+
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -1733,3 +1760,62 @@ def parse_connector_type(url: str) -> str:
         return ""
 
     return m.group(1)
+
+
+def retry(
+    fn,
+    max_retry: int,
+    initial_delay: float = 2.0,
+    max_delay: float = 60.0,
+    should_retry: Callable[[Any], bool] = lambda e: True,
+):
+    for try_index in itertools.count():
+        try:
+            return fn()
+        except Exception as e:
+            if try_index >= max_retry:
+                raise Exception(f"retry() exceed maximum number of retries.")
+
+            if not should_retry(e):
+                raise Exception(f"retry() observe errors that should not be retried.")
+
+            delay = min(initial_delay * (2**try_index), max_delay) * (
+                0.75 + 0.25 * random.random()
+            )
+
+            logger.warning(
+                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
+            )
+            traceback.print_exc()
+
+            time.sleep(delay)
+
+
+def flatten_nested_list(nested_list):
+    if isinstance(nested_list, list):
+        return [
+            item for sublist in nested_list for item in flatten_nested_list(sublist)
+        ]
+    else:
+        return [nested_list]
+
+
+class DeepEPMode(Enum):
+    normal = "normal"
+    low_latency = "low_latency"
+    auto = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.normal, DeepEPMode.auto]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.low_latency, DeepEPMode.auto]
+
+    def resolve(self, forward_mode):
+        if self != DeepEPMode.auto:
+            return self
+
+        if forward_mode.is_decode():
+            return DeepEPMode.low_latency
+        else:
+            return DeepEPMode.normal
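
A usage sketch for the new retry helper (flaky_fetch is a hypothetical stand-in): the delay follows min(initial_delay * 2**try_index, max_delay), scaled by a jitter factor in [0.75, 1.0), and should_retry lets callers re-raise immediately on non-transient errors:

    import random

    from sglang.srt.utils import retry

    def flaky_fetch():
        # Hypothetical operation that fails transiently.
        if random.random() < 0.5:
            raise ConnectionError("transient network error")
        return "ok"

    result = retry(
        flaky_fetch,
        max_retry=3,  # give up when the 4th attempt fails
        initial_delay=2.0,  # nominal delays: 2s, 4s, 8s (capped at 60s)
        should_retry=lambda e: isinstance(e, ConnectionError),
    )
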
sglang/test/runners.py CHANGED
@@ -19,10 +19,16 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+)
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
+from sglang.srt.utils import load_image
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
 
 DEFAULT_PROMPTS = [
@@ -140,7 +146,6 @@ class HFRunner:
     def _get_gme_qwen2_vl_embeddings(
         self, prompts, image_data: Optional[List[str]] = None
     ):
-        from sglang.srt.utils import load_image
 
         images = None
         if image_data is not None:
@@ -226,6 +231,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
             self.processor = AutoProcessor.from_pretrained(model_path)
+        elif "clip" in model_path.lower():
+            self.model = AutoModel.from_pretrained(model_path).cuda()
+            self.processor = AutoProcessor.from_pretrained(model_path)
         else:
             self.model = _get_sentence_transformer_embedding_model(
                 model_path, torch_dtype
@@ -272,6 +280,23 @@ class HFRunner:
             assert not self.output_str_only
             if "gme-qwen2-vl" in model_path.lower():
                 logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+            elif "clip" in model_path.lower():
+                if image_data is not None:
+                    image = load_image(image_data)
+                    inputs = self.processor(
+                        images=image[0], return_tensors="pt"
+                    )
+                    logits = self.model.get_image_features(
+                        pixel_values=inputs.data["pixel_values"].cuda(),
+                    ).tolist()
+                else:
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    )
+                    logits = self.model.get_text_features(
+                        input_ids=inputs.data["input_ids"].cuda(),
+                        attention_mask=inputs.data["attention_mask"].cuda(),
+                    ).tolist()
             else:
                 logits = self.model.encode(prompts).tolist()
             out_queue.put(ModelOutput(embed_logits=logits))
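
Note that load_image returns a (PIL.Image, (width, height)) tuple, which is why the branch above indexes image[0]. For comparison, the same CLIP embedding path outside the runner looks roughly like this (the checkpoint name is only an example):

    import torch
    from transformers import AutoModel, AutoProcessor

    model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    inputs = processor(text=["a photo of a cat"], padding=True, return_tensors="pt")
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)  # shape (1, 512)
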
sglang/test/test_custom_ops.py CHANGED
@@ -82,6 +82,61 @@ if is_cuda:
             dequantize_per_token(ref_y, scale, dtype),
         )
 
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_scaled_fp8_quant_with_padding(dtype) -> None:
+        original_rows = 5
+        x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
+
+        padding_size = 10
+
+        # Test with dynamic quantization
+        y_dynamic, scale_dynamic = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size
+        )
+
+        # Verify output shape has the padded size
+        assert y_dynamic.shape[0] == padding_size
+        assert y_dynamic.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
+        torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
+
+        # Test with static quantization
+        # First get a scale
+        _, scale = scaled_fp8_quant(x, None)
+
+        # Then use it for static quantization with padding
+        y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
+
+        # Verify output shape has the padded size
+        assert y_static.shape[0] == padding_size
+        assert y_static.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_static_without_padding, _ = scaled_fp8_quant(x, scale)
+        torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
+
+        # Test with per-token dynamic quantization
+        y_per_token, scale_per_token = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
+        )
+
+        # Verify output shape has the padded size
+        assert y_per_token.shape[0] == padding_size
+        assert y_per_token.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
+            x, None, use_per_token_if_dynamic=True
+        )
+        torch.testing.assert_close(
+            y_per_token[:original_rows], y_per_token_without_padding
+        )
+        torch.testing.assert_close(
+            scale_per_token[:original_rows], scale_per_token_without_padding
+        )
+
 
 if __name__ == "__main__":
     # Run the specific test function directly
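
In shape terms, the num_token_padding argument exercised by this new test makes scaled_fp8_quant allocate its output with at least that many rows, with real quantized data only in the first original_rows rows; a plain-torch sketch of the contract the assertions check (illustrative, not the kernel itself):

    import torch

    original_rows, padding_size, hidden = 5, 10, 16
    y = torch.zeros(padding_size, hidden)                    # padded output buffer
    y[:original_rows] = torch.randn(original_rows, hidden)   # real rows come first
    assert y.shape == (padding_size, hidden)                 # shape the test asserts
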
sglang/test/test_utils.py CHANGED
@@ -25,11 +25,11 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import get_bool_env_var, kill_process_tree
+from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
 
-DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
+DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
 DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
 DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
     "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
@@ -76,11 +76,14 @@ def is_in_ci():
 
 
 if is_in_ci():
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
 else:
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
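
Worked example of the new port derivation: only the first character of CUDA_VISIBLE_DEVICES is read, so each single-digit leading device id gets its own 100-port block, and the test URL is offset 1000 ports above the runner port:

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
    first_device = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0])
    port = 5000 + first_device * 100          # 5400 in CI (base 7000 -> 7400 locally)
    url = f"http://127.0.0.1:{port + 1000}"   # http://127.0.0.1:6400
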
@@ -1010,26 +1013,10 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 
 class CustomTestCase(unittest.TestCase):
     def _callTestMethod(self, method):
-        _retry_execution(
-            lambda: super(CustomTestCase, self)._callTestMethod(method),
-            max_retry=_get_max_retry(),
+        max_retry = int(
+            os.environ.get("SGLANG_TEST_MAX_RETRY", "1" if is_in_ci() else "0")
         )
-
-
-def _get_max_retry():
-    return int(os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0"))
-
-
-def _retry_execution(fn, max_retry: int):
-    if max_retry == 0:
-        fn()
-        return
-
-    try:
-        fn()
-    except Exception as e:
-        print(
-            f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
+        retry(
+            lambda: super(CustomTestCase, self)._callTestMethod(method),
+            max_retry=max_retry,
         )
-        traceback.print_exc()
-        _retry_execution(fn, max_retry=max_retry - 1)
sglang/utils.py CHANGED
@@ -25,8 +25,6 @@ from IPython.display import HTML, display
 from pydantic import BaseModel
 from tqdm import tqdm
 
-from sglang.srt.utils import kill_process_tree
-
 logger = logging.getLogger(__name__)
 
 
@@ -422,6 +420,8 @@ def terminate_process(process):
     """
     Terminate the process and automatically release the reserved port.
     """
+    from sglang.srt.utils import kill_process_tree
+
     kill_process_tree(process.pid)
 
     lock_socket = process_socket_map.pop(process, None)
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.4.post2"
+__version__ = "0.4.4.post4"
{sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.4.post2
+Version: 0.4.4.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
         Version 2.0, January 2004
@@ -218,6 +218,7 @@ Requires-Dist: numpy
 Requires-Dist: IPython
 Requires-Dist: setproctitle
 Provides-Extra: runtime-common
+Requires-Dist: compressed-tensors; extra == "runtime-common"
 Requires-Dist: datasets; extra == "runtime-common"
 Requires-Dist: decord; extra == "runtime-common"
 Requires-Dist: fastapi; extra == "runtime-common"
@@ -233,21 +234,25 @@ Requires-Dist: pillow; extra == "runtime-common"
 Requires-Dist: prometheus-client>=0.20.0; extra == "runtime-common"
 Requires-Dist: psutil; extra == "runtime-common"
 Requires-Dist: pydantic; extra == "runtime-common"
+Requires-Dist: pynvml; extra == "runtime-common"
 Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
-Requires-Dist: transformers==4.50.0; extra == "runtime-common"
+Requires-Dist: transformers==4.51.0; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
-Requires-Dist: xgrammar==0.1.16; extra == "runtime-common"
+Requires-Dist: compressed-tensors; extra == "runtime-common"
+Requires-Dist: xgrammar==0.1.17; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.0.5.post3; extra == "srt"
+Requires-Dist: sgl-kernel==0.0.8; extra == "srt"
 Requires-Dist: flashinfer_python==0.2.3; extra == "srt"
 Requires-Dist: torch==2.5.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
+Requires-Dist: partial_json_parser; extra == "srt"
+Requires-Dist: einops; extra == "srt"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
@@ -271,7 +276,7 @@ Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
 Provides-Extra: litellm
 Requires-Dist: litellm>=1.0.0; extra == "litellm"
 Provides-Extra: torch-memory-saver
-Requires-Dist: torch_memory_saver>=0.0.3; extra == "torch-memory-saver"
+Requires-Dist: torch_memory_saver>=0.0.4; extra == "torch-memory-saver"
 Provides-Extra: test
 Requires-Dist: jsonlines; extra == "test"
 Requires-Dist: matplotlib; extra == "test"