sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -67,6 +71,7 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.page_size = server_args.page_size
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -145,15 +150,26 @@
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
+            if not global_server_args_dict["use_mla_backend"]:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferMultiStepDraftBackend,
+                )
 
-            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
+                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAMultiStepDraftBackend,
+                )
+
+                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
@@ -170,19 +186,19 @@
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
+        elif self.server_args.attention_backend == "fa3":
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionMultiStepBackend,
             )
 
-            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                 self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
-            self.has_prefill_wrapper_verify = True
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -234,14 +250,11 @@
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info, to_free_cache_loc = self.draft(batch)
+                spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )
 
-            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-
             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +318,59 @@
         )
 
         # Allocate cache locations
-        out_cache_loc = batch.alloc_token_slots(
-            num_seqs * self.topk * self.speculative_num_steps
-        )
+        if self.page_size == 1:
+            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+            )
+        else:
+            if self.topk == 1:
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
+                extend_num_tokens = num_seqs * self.speculative_num_steps
+            else:
+                # In this case, the last partial page needs to be duplicated.
+                # KV cache layout in batch.req_to_token_pool.req_to_token:
+                #
+                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
+                #    prefix     top-k = 0    top-k = 1    top-k = 2
+                #
+                # "-" means prefix tokens
+                # "x" means speculative draft tokens
+                # "." means padded tokens
+
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
+
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            out_cache_loc, token_to_kv_pool_state_backup = (
+                batch.alloc_paged_token_slots_extend(
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    extend_num_tokens,
+                    backup_state=True,
+                )
+            )
+
         assign_draft_cache_locs[(num_seqs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -316,6 +379,7 @@
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
+            self.page_size,
         )
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
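
The paged branch above is easier to follow with concrete numbers. Below is a minimal sketch of the rounding arithmetic using plain Python ints (illustrative values, not taken from the package; the real code operates on batch.seq_lens tensors, and the allocator state captured via backup_state=True is restored after draft_forward in a later hunk):

    # Worked example of the last_page_lens / num_new_pages / seq_lens math above.
    page_size = 4
    prefix_len = 10  # tokens already cached for one request
    num_steps = 3    # speculative draft tokens to allocate
    topk = 1         # the formula below mirrors the topk > 1 branch

    last_page_len = prefix_len % page_size                                    # 2
    num_new_pages = (last_page_len + num_steps + page_size - 1) // page_size  # ceil(5 / 4) = 2
    seq_len = prefix_len // page_size * page_size + num_new_pages * (page_size * topk)
    print(last_page_len, num_new_pages, seq_len)  # 2 2 16: 13 live tokens, 3 padded slots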
@@ -343,6 +407,8 @@
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
         ret = EagleVerifyInput.create(
             spec_info.verified_id,
             score_list,
@@ -354,7 +420,7 @@
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret, out_cache_loc
+        return ret
 
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -411,7 +477,7 @@
         return score_list, token_list, parents_list
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-        spec_info.prepare_for_verify(batch)
+        spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +487,10 @@
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
-            batch, logits_output, self.token_to_kv_pool_allocator
+            batch,
+            logits_output,
+            self.token_to_kv_pool_allocator,
+            self.page_size,
         )
 
         # Post process based on verified outputs.
sglang/srt/utils.py CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
-
 import base64
 import builtins
 import ctypes
@@ -35,8 +34,10 @@ import sys
 import tempfile
 import threading
 import time
+import traceback
 import warnings
 from contextlib import contextmanager
+from enum import Enum
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -53,6 +54,7 @@ import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
+from decord import VideoReader, cpu
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from PIL import Image
@@ -261,7 +263,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus = cuda_device_count_stateless()
+        num_gpus = torch.cuda.device_count()
         assert gpu_id < num_gpus
 
         if torch.cuda.current_device() != gpu_id:
@@ -512,13 +514,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     import soundfile as sf
     from scipy.signal import resample
 
-    # print(f"loading {audio_file}")
     # Load audio data
     if isinstance(audio_file, bytes):
        audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
         audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif audio_file.startswith("http://") or audio_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
+        response = requests.get(audio_file, stream=True, timeout=timeout)
+        audio_file = BytesIO(response.content)
+        response.close()
+        audio, original_sr = sf.read(audio_file)
     elif isinstance(audio_file, str):
         audio, original_sr = sf.read(audio_file)
     else:
@@ -536,10 +543,38 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio
 
 
-def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
-    image = image_size = None
+def encode_video(video_path, frame_count_limit=None):
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
 
-    if isinstance(image_file, bytes):
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
+def load_image(
+    image_file: Union[Image.Image, str, bytes]
+) -> tuple[Image.Image, tuple[int, int]]:
+    image = image_size = None
+    if isinstance(image_file, Image.Image):
+        image = image_file
+        image_size = (image.width, image.height)
+    elif isinstance(image_file, bytes):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
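
The frame sampling in the new encode_video is worth a concrete example. uniform_sample picks n indices spread evenly across the list, centered within each bucket; a standalone copy with illustrative numbers (no decord required):

    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    # e.g. 300 frames at 30 fps, pre-sampled once per second -> 10 candidate indices
    frame_indices = list(range(0, 300, 30))
    print(uniform_sample(frame_indices, 4))  # [30, 90, 180, 240]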
@@ -563,6 +598,10 @@ def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
 
 
 def suppress_other_loggers():
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
     try:
         from vllm.logger import logger as vllm_default_logger
     except ImportError:
@@ -577,10 +616,6 @@
     )
     logging.getLogger("vllm.config").setLevel(logging.ERROR)
 
-    warnings.filterwarnings(
-        "ignore", category=UserWarning, message="The given NumPy array is not writable"
-    )
-
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -1381,47 +1416,6 @@ def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
 
 
-@lru_cache(maxsize=8)
-def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
-    # Note: cuda_visible_devices is not used, but we keep it as an argument for
-    # LRU Cache purposes.
-
-    # Code below is based on
-    # https://github.com/pytorch/pytorch/blob/
-    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
-    # torch/cuda/__init__.py#L831C1-L831C17
-    import torch.version
-
-    if not torch.cuda._is_compiled():
-        return 0
-    if is_hip():
-        # ROCm uses amdsmi instead of nvml for stateless device count
-        # This requires a sufficiently modern version of Torch 2.4.0
-        raw_count = (
-            torch.cuda._device_count_amdsmi()
-            if (hasattr(torch.cuda, "_device_count_amdsmi"))
-            else -1
-        )
-    else:
-        raw_count = torch.cuda._device_count_nvml()
-    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
-    return r
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
-def cuda_device_count_stateless() -> int:
-    """Get number of CUDA devices, caching based on the value of
-    CUDA_VISIBLE_DEVICES at the time of call.
-
-    This should be used instead of torch.cuda.device_count()
-    unless CUDA_VISIBLE_DEVICES has already been set to the desired
-    value."""
-
-    # This can be removed and simply replaced with torch.cuda.get_device_count
-    # after https://github.com/pytorch/pytorch/pull/122815 is released.
-    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
-
-
 def dataclass_to_string_truncated(
     data, max_length=2048, skip_names: Optional[Set[str]] = None
 ):
@@ -1766,3 +1760,62 @@ def parse_connector_type(url: str) -> str:
         return ""
 
     return m.group(1)
+
+
+def retry(
+    fn,
+    max_retry: int,
+    initial_delay: float = 2.0,
+    max_delay: float = 60.0,
+    should_retry: Callable[[Any], bool] = lambda e: True,
+):
+    for try_index in itertools.count():
+        try:
+            return fn()
+        except Exception as e:
+            if try_index >= max_retry:
+                raise Exception("retry() exceeded the maximum number of retries.")
+
+            if not should_retry(e):
+                raise Exception("retry() observed an error that should not be retried.")
+
+            delay = min(initial_delay * (2**try_index), max_delay) * (
+                0.75 + 0.25 * random.random()
+            )
+
+            logger.warning(
+                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
+            )
+            traceback.print_exc()
+
+            time.sleep(delay)
+
+
+def flatten_nested_list(nested_list):
+    if isinstance(nested_list, list):
+        return [
+            item for sublist in nested_list for item in flatten_nested_list(sublist)
+        ]
+    else:
+        return [nested_list]
+
+
+class DeepEPMode(Enum):
+    normal = "normal"
+    low_latency = "low_latency"
+    auto = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.normal, DeepEPMode.auto]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.low_latency, DeepEPMode.auto]
+
+    def resolve(self, forward_mode):
+        if self != DeepEPMode.auto:
+            return self
+
+        if forward_mode.is_decode():
+            return DeepEPMode.low_latency
+        else:
+            return DeepEPMode.normal
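
The new retry helper backs off exponentially with a jitter factor drawn from [0.75, 1.0) and a max_delay ceiling. A sketch of the resulting delay schedule, with the jitter pinned to 0.875 so the output is deterministic (the real helper calls random.random()):

    initial_delay, max_delay, jitter = 2.0, 60.0, 0.875
    for try_index in range(6):
        delay = min(initial_delay * (2**try_index), max_delay) * jitter
        print(f"try {try_index}: sleep {delay:.2f}s")
    # try 0: 1.75s, then 3.50s, 7.00s, 14.00s, 28.00s, try 5: 52.50s (capped at max_delay)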
sglang/test/test_custom_ops.py CHANGED
@@ -82,6 +82,61 @@ if is_cuda:
             dequantize_per_token(ref_y, scale, dtype),
         )
 
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_scaled_fp8_quant_with_padding(dtype) -> None:
+        original_rows = 5
+        x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
+
+        padding_size = 10
+
+        # Test with dynamic quantization
+        y_dynamic, scale_dynamic = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size
+        )
+
+        # Verify output shape has the padded size
+        assert y_dynamic.shape[0] == padding_size
+        assert y_dynamic.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
+        torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
+
+        # Test with static quantization
+        # First get a scale
+        _, scale = scaled_fp8_quant(x, None)
+
+        # Then use it for static quantization with padding
+        y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
+
+        # Verify output shape has the padded size
+        assert y_static.shape[0] == padding_size
+        assert y_static.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_static_without_padding, _ = scaled_fp8_quant(x, scale)
+        torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
+
+        # Test with per-token dynamic quantization
+        y_per_token, scale_per_token = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
+        )
+
+        # Verify output shape has the padded size
+        assert y_per_token.shape[0] == padding_size
+        assert y_per_token.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
+            x, None, use_per_token_if_dynamic=True
+        )
+        torch.testing.assert_close(
+            y_per_token[:original_rows], y_per_token_without_padding
+        )
+        torch.testing.assert_close(
+            scale_per_token[:original_rows], scale_per_token_without_padding
+        )
+
 
 if __name__ == "__main__":
     # Run the specific test function directly
sglang/test/test_utils.py CHANGED
@@ -25,7 +25,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import get_bool_env_var, kill_process_tree
+from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
 
@@ -76,11 +76,14 @@ def is_in_ci():
 
 
 if is_in_ci():
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
 else:
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
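
The new port scheme derives the test port from the first character of CUDA_VISIBLE_DEVICES, so concurrent CI jobs pinned to different GPUs get disjoint port ranges. A worked example (illustrative; note that only the first character is read, so "3,4" selects 3 and a two-digit id like "12" selects 1):

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
    first = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0])  # 3
    print(5000 + first * 100)                               # 5300 in CI (7300 otherwise)
    print(f"http://127.0.0.1:{5000 + first * 100 + 1000}")  # matching DEFAULT_URL_FOR_TEST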
@@ -1010,26 +1013,10 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 
 class CustomTestCase(unittest.TestCase):
     def _callTestMethod(self, method):
-        _retry_execution(
-            lambda: super(CustomTestCase, self)._callTestMethod(method),
-            max_retry=_get_max_retry(),
+        max_retry = int(
+            os.environ.get("SGLANG_TEST_MAX_RETRY", "1" if is_in_ci() else "0")
         )
-
-
-def _get_max_retry():
-    return int(os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0"))
-
-
-def _retry_execution(fn, max_retry: int):
-    if max_retry == 0:
-        fn()
-        return
-
-    try:
-        fn()
-    except Exception as e:
-        print(
-            f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
+        retry(
+            lambda: super(CustomTestCase, self)._callTestMethod(method),
+            max_retry=max_retry,
         )
-        traceback.print_exc()
-        _retry_execution(fn, max_retry=max_retry - 1)
sglang/utils.py CHANGED
@@ -25,8 +25,6 @@ from IPython.display import HTML, display
 from pydantic import BaseModel
 from tqdm import tqdm
 
-from sglang.srt.utils import kill_process_tree
-
 logger = logging.getLogger(__name__)
 
 
@@ -422,6 +420,8 @@
     """
     Terminate the process and automatically release the reserved port.
     """
+    from sglang.srt.utils import kill_process_tree
+
     kill_process_tree(process.pid)
 
     lock_socket = process_socket_map.pop(process, None)
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.4.post3"
+__version__ = "0.4.4.post4"
{sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.4.post3
+Version: 0.4.4.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                          Version 2.0, January 2004
@@ -234,18 +234,19 @@ Requires-Dist: pillow; extra == "runtime-common"
 Requires-Dist: prometheus-client>=0.20.0; extra == "runtime-common"
 Requires-Dist: psutil; extra == "runtime-common"
 Requires-Dist: pydantic; extra == "runtime-common"
+Requires-Dist: pynvml; extra == "runtime-common"
 Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
-Requires-Dist: transformers==4.50.0; extra == "runtime-common"
+Requires-Dist: transformers==4.51.0; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: compressed-tensors; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.17; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.0.5.post4; extra == "srt"
+Requires-Dist: sgl-kernel==0.0.8; extra == "srt"
 Requires-Dist: flashinfer_python==0.2.3; extra == "srt"
 Requires-Dist: torch==2.5.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"